{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bfc94e6c-5590-43df-9221-243224dbb60a",
   "metadata": {},
   "source": [
    "|任务名称|难度|\n",
    "|--|--|\n",
    "|任务1：数据集读取|低、1|\n",
    "|任务2：文本数据分析|低、1|\n",
    "|任务3：文本相似度（统计特征）|中、2|\n",
    "|任务4：文本相似度（词向量与句子编码）|高、3|\n",
    "|任务5：文本匹配模型（LSTM孪生网络）|中、2|\n",
    "|任务6：文本匹配模型（Sentence-BERT模型）|高、3|\n",
    "|任务7：文本匹配模型（SimCSE模型）|高、3|"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6976d2d3-bd58-4359-8c07-48eaffa8dff4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\env\\Anaconda3\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n",
      "D:\\env\\Anaconda3\\lib\\site-packages\\numpy\\.libs\\libopenblas.WCDJNK7YVMPZQ2ME2ZZHJJRJ3JIKNDB7.gfortran-win_amd64.dll\n",
      "D:\\env\\Anaconda3\\lib\\site-packages\\numpy\\.libs\\libopenblas.XWYDX2IKJW2NMTWSFYNGFUWKQU3LYTCZ.gfortran-win_amd64.dll\n",
      "  warnings.warn(\"loaded more than 1 DLL from .libs:\"\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import jieba\n",
    "from tqdm import tqdm\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.feature_extraction.text import TfidfTransformer, TfidfVectorizer\n",
    "from scipy import spatial\n",
    "\n",
    "# NOTE: `from transformers import AutoTokenizer, AutoModel` used to live here,\n",
    "# but those names were silently shadowed by the paddlenlp import below, so the\n",
    "# Hugging Face import is disabled; every Auto* name in this notebook resolves\n",
    "# to paddlenlp.transformers.\n",
    "# from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "from gensim.models import Word2Vec, TfidfModel\n",
    "# from gensim.summarization import bm25\n",
    "from sklearn.decomposition import TruncatedSVD\n",
    "from collections import Counter\n",
    "import math\n",
    "\n",
    "import paddle\n",
    "import paddle.nn as nn\n",
    "from paddle.io import Dataset, DataLoader\n",
    "from paddle.nn import functional\n",
    "from paddlenlp.transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel\n",
    "\n",
    "import logging\n",
    "\n",
    "# Silence noisy library log output. logging.disable(level) disables that level\n",
    "# and everything below it, so the WARNING call alone already covers DEBUG;\n",
    "# both are kept for clarity.\n",
    "logging.disable(logging.DEBUG)  # suppress DEBUG-level logging\n",
    "logging.disable(logging.WARNING)  # suppress WARNING-level (and lower) logging\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99441b90-72eb-4fc8-97a9-ece1812ba4c9",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## 任务1：数据集读取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "52e4d6f0-68b2-4261-81d9-36bb6a0f8a75",
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "def load_lcqmc():\n",
    "    '''Download and load the LCQMC text-matching dataset.\n",
    "\n",
    "    Each split is a tab-separated file served as a zip archive; pandas\n",
    "    infers the compression from the URL suffix.\n",
    "\n",
    "    Returns:\n",
    "        (train, valid, test) DataFrames, each with columns\n",
    "        ['query1', 'query2', 'label'].\n",
    "    '''\n",
    "    url_template = 'https://mirror.coggle.club/dataset/LCQMC.{}.data.zip'\n",
    "    splits = []\n",
    "    for split in ('train', 'valid', 'test'):\n",
    "        splits.append(pd.read_csv(url_template.format(split), sep='\\t', names=['query1', 'query2', 'label']))\n",
    "    return tuple(splits)\n",
    "\n",
    "# Load the three dataset splits.\n",
    "train_df, valid_df, test_df = load_lcqmc()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e88176fa-2fd9-46d2-8f67-89103e41c892",
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Persist the raw splits so later cells can reload them without re-downloading.\n",
    "import os\n",
    "os.makedirs('data', exist_ok=True)  # to_csv does not create missing directories\n",
    "\n",
    "# Note: the validation split is deliberately written as 'val.csv'.\n",
    "train_df.to_csv('data/train.csv', index=False)\n",
    "test_df.to_csv('data/test.csv', index=False)\n",
    "valid_df.to_csv('data/val.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3a5c9a5-ed1a-4003-8325-73a19ffaf7d7",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## 任务2：文本数据分析"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "e1454a2c-c11e-4f40-84e6-4b3b91d28a7c",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "相似语句统计\n",
      "字数差标准差: 1.523462829258909\n",
      "字数差均值: 1.4892476222090725\n",
      "不相似语句统计\n",
      "字数差标准差: 2.8916345878531464\n",
      "字数差均值: 2.6640949377195784\n"
     ]
    }
   ],
   "source": [
    "# Character-count difference between the two queries of each pair,\n",
    "# summarised separately for similar (label 1) and dissimilar (label 0) pairs.\n",
    "train_df = pd.read_csv('data/train.csv')\n",
    "train_df['len1'] = train_df['query1'].map(len)\n",
    "train_df['len2'] = train_df['query2'].map(len)\n",
    "train_df['len_diff'] = (train_df['len1'] - train_df['len2']).abs()\n",
    "\n",
    "for title, lab in [('相似语句统计', 1), ('不相似语句统计', 0)]:\n",
    "    subset = train_df.loc[train_df['label'] == lab, 'len_diff']\n",
    "    print(title)\n",
    "    print('字数差标准差:', subset.std())\n",
    "    print('字数差均值:', subset.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "2bc5c4aa-7c5b-4480-a928-f1af345a640f",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "共有 245951 句话\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████| 245951/245951 [00:09<00:00, 26178.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "共有 39657 个单词\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████| 245951/245951 [00:00<00:00, 569167.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "共有 5039 个字符\n"
     ]
    }
   ],
   "source": [
    "# Count unique sentences, words (jieba tokens) and characters in the corpus.\n",
    "text_list = list(set(train_df['query1'].tolist() + train_df['query2'].tolist()))\n",
    "\n",
    "print('共有', len(text_list), '句话')\n",
    "\n",
    "# Accumulate directly into sets instead of deduplicating a big list afterwards.\n",
    "word_set = set()\n",
    "for sentence in tqdm(text_list):\n",
    "    word_set.update(jieba.lcut(sentence, cut_all=False))\n",
    "word_list = list(word_set)\n",
    "print('共有', len(word_list), '个单词')\n",
    "\n",
    "char_set = set()\n",
    "for sentence in tqdm(text_list):\n",
    "    char_set.update(sentence)\n",
    "char_list = list(char_set)\n",
    "print('共有', len(char_list), '个字符')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5320e298-b131-472e-9794-5180383d85c0",
   "metadata": {},
   "source": [
    "## 任务3：文本相似度（统计特征）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cdd6eec",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "#### 统计长度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "61bfe30d-9242-4acd-b257-2af35bc73a1b",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAD8CAYAAAB3u9PLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAXLklEQVR4nO3df4zV9Z3v8edbBsWtBREGgwx22Mj1B9aKTAx7abbbsr2waheb4O5s6jq95YZer9cfjckWa5Pr/mEizc3akq7mGuk6sl6VsLWSWrtrwcZswuKC26sgcpkuVEcRRlBWb4Nl2Pf943zGHsb5cWYY5gfzfCQn53ve5/v5ns9H2vOaz+f7PedEZiJJ0hkj3QFJ0uhgIEiSAANBklQYCJIkwECQJBUGgiQJqDEQIuLciNgQEa9FxK6I+L2IOC8inouIPeV+atX+d0VEW0TsjoglVfUFEfFKeW5NRESpnxURT5b61ohoHPKRSpL6VOsM4XvATzPzEuAzwC5gFbApM+cCm8pjIuIyoBmYBywFHoiICeU4DwIrgbnltrTUVwDvZuZFwP3A6pMclyRpgPoNhIiYDPw+sBYgM3+Tme8By4DWslsrcH3ZXgY8kZkfZuZeoA24OiJmApMzc0tWPg33aLc2XcfaACzumj1IkoZHXQ37/C7QAfxNRHwG2A7cDpyfmfsBMnN/RMwo+88C/qmqfXupHSvb3etdbd4ox+qMiCPANOCd3jo1ffr0bGxsrKH7kqQu27dvfycz63t6rpZAqAOuAm7NzK0R8T3K8lAvevrLPvuo99XmxANHrKSy5MSFF17Itm3b+uq3JKmbiPhVb8/Vcg6hHWjPzK3l8QYqAXGgLANR7g9W7T+7qn0D8FapN/RQP6FNRNQBU4DD3TuSmQ9lZlNmNtXX9xhwkqRB6jcQMvNt4I2IuLiUFgOvAhuBllJrAZ4u2xuB5nLl0BwqJ49fLMtL70fEwnJ+4KZubbqOtRzYnH7rniQNq1qWjABuBR6LiDOBfwX+M5UwWR8RK4DXgRsAMnNnRKynEhqdwC2Zebwc52bgEeBs4Nlyg8oJ63UR0UZlZtB8kuOSJA1QjNU/xJuamtJzCNL4duzYMdrb2zl69OhId2XUmTRpEg0NDUycOPGEekRsz8ymntrUOkOQpFGnvb2dT37ykzQ2NuKV6r+VmRw6dIj29nbmzJlTczu/ukLSmHX06FGmTZtmGHQTEUybNm3AMycDQdKYZhj0bDD/XQwESRLgOQRJp5HGVc8M6fH23Xdtv/ucc845fPDBB70fY98+rrvuOnbs2FHz6371q1/luuuuY/ny5T0+v3fvXpqbmzl8+DBXXXUV69at48wzz6z5+L1xhgBwz5SR7oEk1eyb3/wm3/jGN9izZw9Tp05l7dq1Q3JcA0GShsAHH3zA4sWLueqqq/j0pz/N008//dFznZ2dtLS0cMUVV7B8+XJ+/etfA7B9+3Y+97nPsWDBApYsWcL+/fv7fZ3MZPPmzR/NHlpaWvjRj340JGMwECRpCEyaNImnnnqKl156ieeff54777yTrs957d69m5UrV/Lyyy8zefJkHnjgAY4dO8att97Khg0b2L59O1/72te4++67+32dQ4cOce6551JXV1nxb2ho4M033xySMXgOQZKGQGbyrW99ixdeeIEzzjiDN998kwMHDgAwe/ZsFi1aBMCNN97ImjVrWLp0KTt27OCLX/wiAMePH2fmzJk1vU53Q3WllYEgSUPgscceo6Ojg+3btzNx4kQaGxs/+hxA9zfsiCAzmTdvHlu2bBnQ60yfPp333nuPzs5O6urqaG9v54ILLhiSMbhkJElD4MiRI8yYMYOJEyfy/PPP86tf/fZbpl9//fWP3vgff/xxPvvZz3LxxRfT0dHxUf3YsWPs3Lmz39eJCD7/+c+zYcMGAFpbW1m2bNmQjMEZgqTTRi2XiZ4qX/nKV/jSl75EU1MTV155JZdccslH
z1166aW0trby9a9/nblz53LzzTdz5plnsmHDBm677TaOHDlCZ2cnd9xxB/Pmzev3tVavXk1zczPf/va3mT9/PitWrBiSMfjldlC57PSeI0NzLEnDZteuXVx66aUj3Y1Rq6f/Pn19uZ1LRpIkwCUjSRq1vvzlL7N3794TaqtXr2bJkiWn5PUMBEkapZ566qlhfT2XjCRJgIEgSSoMBEkSYCBIkgpPKks6fQz1V9nX8Pmkkfg9hO9///t897vf5Ze//CUdHR1Mnz695mP3xRmCJI0xixYt4mc/+xmf+tSnhvS4BoIkDYHh+j0EgPnz59PY2DjkYzAQJGkIDNfvIZxKnkOQpCEwXL+HcCoZCJI0BIbr9xBOJZeMJGkIDNfvIZxKNc0QImIf8D5wHOjMzKaIOA94EmgE9gF/kpnvlv3vAlaU/W/LzL8v9QXAI8DZwE+A2zMzI+Is4FFgAXAI+NPM3DckI5Q0fozg19gP5+8hrFmzhu985zu8/fbbXHHFFVxzzTU8/PDDJz2Gmn4PoQRCU2a+U1X7DnA4M++LiFXA1Mz8ZkRcBjwOXA1cAPwM+A+ZeTwiXgRuB/6JSiCsycxnI+K/AVdk5n+NiGbgy5n5p331yd9DkOTvIfRtOH8PYRnQWrZbgeur6k9k5oeZuRdoA66OiJnA5MzckpUUerRbm65jbQAWx1D9arQkqSa1nlRO4B8iIoH/lZkPAedn5n6AzNwfETPKvrOozAC6tJfasbLdvd7V5o1yrM6IOAJMA95Bksap0fp7CIsy863ypv9cRLzWx749/WWffdT7anPigSNWAisBLrzwwr57LElj3Kj8PYTMfKvcHwSeonJ+4EBZBqLcHyy7twOzq5o3AG+VekMP9RPaREQdMAU43EM/HsrMpsxsqq+vr6Xrkk5zY/V34U+1wfx36TcQIuITEfHJrm3gPwE7gI1AS9mtBej6nPZGoDkizoqIOcBc4MWyvPR+RCws5wdu6tam61jLgc3pv7KkfkyaNIlDhw4ZCt1kJocOHWLSpEkDalfLktH5wFPlHG8d8L8z86cR8c/A+ohYAbwO3FA6sjMi1gOvAp3ALZl5vBzrZn572emz5QawFlgXEW1UZgbNAxqFpHGpoaGB9vZ2Ojo6Rroro86kSZNoaGjof8cqNV12Ohp52akkDdypuuxUknQaMRAkSYCBIEkqDARJEmAgSJIKA0GSBBgIkqTCQJAkAQaCJKkwECRJgIEgSSoMBEkSYCBIkgoDQZIEGAiSpMJAkCQBBoIkqTAQJEmAgSBJKgwESRJgIEiSCgNBkgQYCJKkwkCQJAEGgiSpMBAkSYCBIEkqag6EiJgQEf8SET8uj8+LiOciYk+5n1q1710R0RYRuyNiSVV9QUS8Up5bExFR6mdFxJOlvjUiGodwjJKkGgxkhnA7sKvq8SpgU2bOBTaVx0TEZUAzMA9YCjwQERNKmweBlcDcclta6iuAdzPzIuB+YPWgRiNJGrSaAiEiGoBrgYerysuA1rLdClxfVX8iMz/MzL1AG3B1RMwEJmfmlsxM4NFubbqOtQFY3DV7kCQNj1pnCN8F/gL496ra+Zm5H6Dczyj1WcAbVfu1l9qsst29fkKbzOwEjgDTah2EJOnk9RsIEXEdcDAzt9d4zJ7+ss8+6n216d6XlRGxLSK2dXR01NgdSVItapkhLAL+OCL2AU8AX4iIvwUOlGUgyv3Bsn87MLuqfQPwVqk39FA/oU1E1AFTgMPdO5KZD2VmU2Y21dfX1zRASVJt+g2EzLwrMxsys5HKyeLNmXkjsBFoKbu1AE+X7Y1Ac7lyaA6Vk8cvlmWl9yNiYTk/cFO3Nl3HWl5e42MzBEnSqVN3Em3vA9ZHxArgdeAGgMzcGRHrgVeBTuCWzDxe2twMPAKcDTxbbgBrgXUR0UZlZtB8Ev2SJA3CgAIhM38O/LxsHwIW97LfvcC9PdS3AZf3UD9KCRRJ0sjwk8qSJMBAkCQVBoIkCTAQJEmFgSBJ
AgwESVJhIEiSAANBklQYCJIkwECQJBUGgiQJMBAkSYWBIEkCDARJUmEgSJIAA0GSVBgIkiTAQJAkFQaCJAkwECRJhYEgSQIMBElSYSBIkgADQZJUGAiSJMBAkCQV4zcQ7pky0j2QpFFl/AaCJOkE/QZCREyKiBcj4v9ExM6I+MtSPy8inouIPeV+alWbuyKiLSJ2R8SSqvqCiHilPLcmIqLUz4qIJ0t9a0Q0noKxSpL6UMsM4UPgC5n5GeBKYGlELARWAZsycy6wqTwmIi4DmoF5wFLggYiYUI71ILASmFtuS0t9BfBuZl4E3A+sPvmhSZIGot9AyIoPysOJ5ZbAMqC11FuB68v2MuCJzPwwM/cCbcDVETETmJyZWzIzgUe7tek61gZgcdfsQZI0PGo6hxAREyLiF8BB4LnM3Aqcn5n7Acr9jLL7LOCNqubtpTarbHevn9AmMzuBI8C0HvqxMiK2RcS2jo6OmgYoSapNTYGQmccz80qggcpf+5f3sXtPf9lnH/W+2nTvx0OZ2ZSZTfX19f30WpI0EAO6yigz3wN+TmXt/0BZBqLcHyy7tQOzq5o1AG+VekMP9RPaREQdMAU4PJC+SZJOTi1XGdVHxLll+2zgD4HXgI1AS9mtBXi6bG8EmsuVQ3OonDx+sSwrvR8RC8v5gZu6tek61nJgcznPIEkaJnU17DMTaC1XCp0BrM/MH0fEFmB9RKwAXgduAMjMnRGxHngV6ARuyczj5Vg3A48AZwPPlhvAWmBdRLRRmRk0D8XgJEm16zcQMvNlYH4P9UPA4l7a3Avc20N9G/Cx8w+ZeZQSKJKkkeEnlSVJgIEgSSoMBEkSYCBIkgoDQZIEGAiSpMJAkCQBBoIkqTAQJEmAgSBJKsZ9IDSuemakuyBJo8K4DwRJUoWBIEkCDARJUmEgSJIAA0GSVBgIkiTAQJAkFQaCJAkwECRJhYEgSQIMBElSYSBIkgADQZJUGAiSJMBAkCQVBoIkCaghECJidkQ8HxG7ImJnRNxe6udFxHMRsafcT61qc1dEtEXE7ohYUlVfEBGvlOfWRESU+lkR8WSpb42IxlMwVklSH2qZIXQCd2bmpcBC4JaIuAxYBWzKzLnApvKY8lwzMA9YCjwQERPKsR4EVgJzy21pqa8A3s3Mi4D7gdVDMDZJ0gD0GwiZuT8zXyrb7wO7gFnAMqC17NYKXF+2lwFPZOaHmbkXaAOujoiZwOTM3JKZCTzarU3XsTYAi7tmD5Kk4TGgcwhlKWc+sBU4PzP3QyU0gBllt1nAG1XN2kttVtnuXj+hTWZ2AkeAaQPpmyTp5NQcCBFxDvB3wB2Z+W997dpDLfuo99Wmex9WRsS2iNjW0dHRX5clSQNQUyBExEQqYfBYZv6wlA+UZSDK/cFSbwdmVzVvAN4q9YYe6ie0iYg6YApwuHs/MvOhzGzKzKb6+vpaui5JqlEtVxkFsBbYlZl/VfXURqClbLcAT1fVm8uVQ3OonDx+sSwrvR8RC8sxb+rWputYy4HN5TyDJGmY1NWwzyLgz4FXIuIXpfYt4D5gfUSsAF4HbgDIzJ0RsR54lcoVSrdk5vHS7mbgEeBs4Nlyg0rgrIuINiozg+aTG5YkaaD6DYTM/Ed6XuMHWNxLm3uBe3uobwMu76F+lBIoI6lx1TPsu+/ake6GJI0IP6ksSQIMBElSYSBIkgADQZJUGAiSJGCcBkLjqmdGuguSNOqMy0CQJH2cgSBJAgwESVJhIEiSAANBklQYCJIkwECQJBUGgiQJMBAkSYWBIEkCDARJUmEgSJIAA0GSVBgIkiTAQJAkFQZCD/y9BEnjkYEgSQIMhN7dM2WkeyBJw8pAkCQBBoIkqTAQJElADYEQET+IiIMRsaOqdl5EPBcRe8r91Krn7oqItojYHRFLquoLIuKV8tyaiIhSPysiniz1rRHROMRjlCTVoJYZwiPA0m61VcCmzJwLbCqPiYjLgGZgXmnzQERMKG0eBFYCc8ut65gr
gHcz8yLgfmD1YAcjSRq8fgMhM18ADncrLwNay3YrcH1V/YnM/DAz9wJtwNURMROYnJlbMjOBR7u16TrWBmBx1+xBkjR8BnsO4fzM3A9Q7meU+izgjar92kttVtnuXj+hTWZ2AkeAaYPslyRpkIb6pHJPf9lnH/W+2nz84BErI2JbRGzr6OgYZBclST0ZbCAcKMtAlPuDpd4OzK7arwF4q9Qbeqif0CYi6oApfHyJCoDMfCgzmzKzqb6+fpBdlyT1ZLCBsBFoKdstwNNV9eZy5dAcKiePXyzLSu9HxMJyfuCmbm26jrUc2FzOM0iShlFdfztExOPAHwDTI6Id+B/AfcD6iFgBvA7cAJCZOyNiPfAq0AnckpnHy6FupnLF0tnAs+UGsBZYFxFtVGYGzUMyMknSgPQbCJn5Z708tbiX/e8F7u2hvg24vIf6UUqgjEaNq55h333XjnQ3JOmU85PKkiTAQJAkFQaCJAkwECRJhYEgSQIMBElSYSBIkgADQZJUGAiSJMBAkCQVBoIkCTAQanPPFKDyvUaSdLoyECRJgIEgSSoMBEkSYCBIkgoDQZIEGAiSpMJAkCQBBoIkqTAQJEmAgSBJKgwESRJgIJw0v99I0unCQJAkAQaCJKkwEAarfCV2NZePJI1loyYQImJpROyOiLaIWDXS/ZGk8WZUBEJETAD+Gvgj4DLgzyLispHt1QD0MFvoUj1rcAYhaTQbFYEAXA20Zea/ZuZvgCeAZSPcp8HpIxyqdYWDISFptBgtgTALeKPqcXupjV3VwdDbdjf9zSacbUg6lSIzR7oPRMQNwJLM/C/l8Z8DV2fmrd32WwmsLA8vBnafxMtOB945ifajleMaWxzX2HI6jOtTmVnf0xN1w92TXrQDs6seNwBvdd8pMx8CHhqKF4yIbZnZNBTHGk0c19jiuMaW03VcXUbLktE/A3MjYk5EnAk0AxtHuE+SNK6MihlCZnZGxH8H/h6YAPwgM3eOcLckaVwZFYEAkJk/AX4yjC85JEtPo5DjGlsc19hyuo4LGCUnlSVJI2+0nEOQJI2wcRcIp8tXZETE7Ih4PiJ2RcTOiLi91M+LiOciYk+5nzrSfR2MiJgQEf8SET8uj8f8uCLi3IjYEBGvlX+33ztNxvWN8r/BHRHxeERMGqvjiogfRMTBiNhRVet1LBFxV3kv2R0RS0am10NnXAXCmP+KjBN1Andm5qXAQuCWMpZVwKbMnAtsKo/HotuBXVWPT4dxfQ/4aWZeAnyGyvjG9LgiYhZwG9CUmZdTuSikmbE7rkeApd1qPY6l/P+tGZhX2jxQ3mPGrHEVCJxGX5GRmfsz86Wy/T6VN5dZVMbTWnZrBa4fkQ6ehIhoAK4FHq4qj+lxRcRk4PeBtQCZ+ZvMfI8xPq6iDjg7IuqA36HyGaIxOa7MfAE43K3c21iWAU9k5oeZuRdoo/IeM2aNt0A4/b4iA4iIRmA+sBU4PzP3QyU0gBkj2LXB+i7wF8C/V9XG+rh+F+gA/qYshT0cEZ9gjI8rM98E/ifwOrAfOJKZ/8AYH1c3vY3ltHs/GW+BED3UxvRlVhFxDvB3wB2Z+W8j3Z+TFRHXAQczc/tI92WI1QFXAQ9m5nzg/zF2llF6VdbTlwFzgAuAT0TEjSPbq2Fz2r2fjLdAqOkrMsaKiJhIJQwey8wflvKBiJhZnp8JHByp/g3SIuCPI2IflSW9L0TE3zL2x9UOtGfm1vJ4A5WAGOvj+kNgb2Z2ZOYx4IfAf2Tsj6tab2M5rd5PYPwFwmnzFRkREVTWo3dl5l9VPbURaCnbLcDTw923k5GZd2VmQ2Y2Uvn32ZyZNzL2x/U28EZEXFxKi4FXGePjorJUtDAifqf8b3IxlfNZY31c1Xoby0agOSLOiog5wFzgxRHo39DJzHF1A64B/i/wS+Duke7PSYzjs1Smpy8Dvyi3a4BpVK6E2FPuzxvpvp7EGP8A+HHZHvPjAq4EtpV/sx8BU0+Tcf0l8BqwA1gH
nDVWxwU8TuVcyDEqM4AVfY0FuLu8l+wG/mik+3+yNz+pLEkCxt+SkSSpFwaCJAkwECRJhYEgSQIMBElSYSBIkgADQZJUGAiSJAD+P7/s3jAAgc06AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def get_y(x, count):\n",
    "    '''Return the frequency of each value in `x`, defaulting to 0.\n",
    "\n",
    "    Args:\n",
    "        x: iterable of len-diff values to look up.\n",
    "        count: mapping from len-diff value to its frequency.\n",
    "\n",
    "    Returns:\n",
    "        list of frequencies aligned with `x`.\n",
    "    '''\n",
    "    return [count.get(k, 0) for k in x]\n",
    "\n",
    "# Character-length difference per pair.\n",
    "train_df = pd.read_csv('data/train.csv')\n",
    "train_df['len1'] = train_df['query1'].map(lambda x: len(x))\n",
    "train_df['len2'] = train_df['query2'].map(lambda x: len(x))\n",
    "train_df['len_diff'] = abs(train_df['len1'] - train_df['len2'])\n",
    "train_df = train_df.sort_values(by='len_diff')\n",
    "# Frequency of each len_diff value, split by label.\n",
    "x = train_df['len_diff'].unique()\n",
    "count_0 = train_df[train_df['label'] == 0]['len_diff'].value_counts().to_dict()\n",
    "count_1 = train_df[train_df['label'] == 1]['len_diff'].value_counts().to_dict()\n",
    "y_0 = get_y(x, count_0)\n",
    "y_1 = get_y(x, count_1)\n",
    "# Grouped bar chart; use the `width` constant instead of repeating 0.3.\n",
    "width = 0.3\n",
    "plt.bar(x, y_0, width=width, label='label_0')\n",
    "plt.bar(x + width, y_1, width=width, label='label_1')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2b19154",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "#### 统计词语个数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "b91889cc",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAD4CAYAAADsKpHdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAXGUlEQVR4nO3dcZBV5Znn8e+jEHGjICJYhGbTuLBGIYrSZUyZmiTLZmAnU4NW4dipZGw3pNiyjJtJpWpF88e4f1AlW7s6obJaxYas6DoqxYyRmqzjGDTlboWFNI6rAkOkAsEWFlowjFYKFfLsH/dtvbSX7tvNhdu3+/upunXPfc55z31fb8mvz3vOPTcyE0mSzml2ByRJI4OBIEkCDARJUmEgSJIAA0GSVIxrdgeG65JLLsn29vZmd0OSWsq2bdveysyptda1bCC0t7fT3d3d7G5IUkuJiN+cap1TRpIkwECQJBUGgiQJaOFzCJL0wQcf0NPTw7Fjx5rdlRFnwoQJtLW1MX78+LrbGAiSWlZPTw8XXngh7e3tRESzuzNiZCaHDx+mp6eHWbNm1d3OKSNJLevYsWNMmTLFMOgnIpgyZcqQj5wMBEktzTCobTj/XQwESRLgOQRJo0j7ip82dH977/tqQ/c30nmEMBT3TvroIUnABRdcMOD6vXv3Mm/evCHt87bbbmPDhg2nXL9nzx4+97nPMWfOHG655Rbef//9Ie3/VAwESWoxd911F9/97nd5/fXXmTx5MmvXrm3Ifg0ESWqAd999l4ULF3Lttdfy2c9+lqeffvrDdcePH6erq4urrrqKpUuX8rvf/Q6Abdu28cUvfpEFCxawaNEiDhw4MOj7ZCbPP/88S5cuBaCrq4uf/OQnDRmDgSBJDTBhwgSeeuopXnrpJV544QW+973v0feb9bt27WL58uW88sorTJw4kQcffJAPPviAO++8kw0bNrBt2za++c1v8v3vf3/Q9zl8+DAXXXQR48ZVTgG3tbXx5ptvNmQMdZ1UjoiLgB8B84AEvgnsAp4E2oG9wJ9m5ttl+7uBZcAJ4N9n5rOlvgB4GDgf+J/AdzIzI+I84BFgAXAYuCUz9zZgfJJ0VmQm99xzDy+++CLnnHMOb775JgcPHgRg5syZ3HDDDQB84xvfYPXq1SxevJjXXnuNr3zlKwCcOHGC6dOn1/U+/TXq0tt6rzL6AfB3mbk0Ij4B/DPgHmBTZt4XESuAFcBdEXEl0AnMBT4F/Cwi/mVmngAeApYD/4dKICwGnqESHm9n5uyI6ARWAbc0ZISSdBY89thj9Pb2sm3bNsaPH097e/uHXwzr/w92RJCZzJ07l82bNw/pfS655BJ++9vfcvz4ccaNG0dPTw+f+tSnGjKGQQMhIiYCfwDcBpCZ7wPvR8QS4Etls3XAz4G7gCXAE5n5HrAnInYD10XEXmBiZm4u+30EuJFKICwB7i372gD8MCIia0WhJJ1CMy8TPXr0KNOmTWP8+PG88MIL/OY3H/3swL59+9i8eTOf//znefzxx/nCF77A5ZdfTm9v74f1Dz74gF/96lfMnTt3wPeJCL785S+zYcMGOjs7WbduHUuWLGnIGOo5h3AZ0Av894j4h4j4UUR8Erg0Mw8AlOdpZfsZwBtV7XtKbUZZ7l8/qU1mHgeOAlP6dyQilkdEd0R09/b21jlESTrzvv71r9Pd3U1HRwePPfYYn/nMZz5cd8UVV7Bu3Tquuuoqjhw5wu23384nPvEJNmzYwF133cXVV1/N/Pnz+cUvflHXe61atYr777+f2bNnc/jwYZYtW9aQMdQzZTQOuBa4MzO3RMQPqEwPnUqtyawcoD5Qm5MLmWuANQAdHR0ePUhqunfffReoTOWcavpnx44dNevz58/nxRdf/Fj94YcfHvA9L7vsMrZu3Tq0jtahniOEHqAnM7eU1xuoBMTBiJgOUJ4PVW0/s6p9G7C/1Ntq1E9qExHjgEnAkaEORpI0
fIMGQmb+P+CNiLi8lBYCO4CNQFepdQF9F91uBDoj4ryImAXMAbaWaaV3IuL6qJxhubVfm759LQWe9/yBpLHupptuYv78+Sc9nn322TP2fvVeZXQn8Fi5wujXwL+lEibrI2IZsA+4GSAzt0fEeiqhcRy4o1xhBHA7H112+kx5AKwFHi0noI9QuUpJksa0p5566qy+X12BkJkvAx01Vi08xfYrgZU16t1UvsvQv36MEiiSpObwm8qSJMBAkCQV/h6CpNGj0bemv/doY/c3wnmEIEmnoRm/h/DDH/6Q2bNnExG89dZbQ9r3QAwESWoxN9xwAz/72c/49Kc/3dD9GgiS1ABn6/cQAK655hra29sbPgYDQZIa4Gz9HsKZ5EllSWqAs/V7CGeSgSBJDXC2fg/hTDIQJI0eTbxM9Gz9HsKZ5DkESWqAs/l7CKtXr6atrY2enh6uuuoqvvWtbzVkDNGqNxXt6OjI7u7us/um1V96GWNfWJFGop07d3LFFVc0uxsjVq3/PhGxLTNr3ZvOIwRJUoXnECRphLrpppvYs2fPSbVVq1axaNGiM/J+BoKklpaZH7uKZ7Q4nd9DGM7pAKeMJLWsCRMmcPjw4WH94zeaZSaHDx9mwoQJQ2rnEYKkltV3pU1vb2+zuzLiTJgwgba2tsE3rGIgSGpZ48ePZ9asWc3uxqjhlJEkCTAQJEmFgSBJAgwESVJhIEiSAANBklTUFQgRsTciXo2IlyOiu9QujojnIuL18jy5avu7I2J3ROyKiEVV9QVlP7sjYnWUrxdGxHkR8WSpb4mI9gaPU5I0iKEcIXw5M+dX3SVvBbApM+cAm8prIuJKoBOYCywGHoyIc0ubh4DlwJzyWFzqy4C3M3M28ACwavhDkiQNx+lMGS0B1pXldcCNVfUnMvO9zNwD7Aaui4jpwMTM3JyV75k/0q9N3742AAtjtN6cRJJGqHoDIYG/j4htEbG81C7NzAMA5Xlaqc8A3qhq21NqM8py//pJbTLzOHAUmNK/ExGxPCK6I6Lbr6pLUmPVe+uKGzJzf0RMA56LiH8cYNtaf9nnAPWB2pxcyFwDrIHKD+QM3GVJ0lDUdYSQmfvL8yHgKeA64GCZBqI8Hyqb9wAzq5q3AftLva1G/aQ2ETEOmAQcGfpwJEnDNWggRMQnI+LCvmXgD4HXgI1AV9msC3i6LG8EOsuVQ7OonDzeWqaV3omI68v5gVv7tenb11Lg+fR+tpJ0VtUzZXQp8FQ5xzsO+KvM/LuI+CWwPiKWAfuAmwEyc3tErAd2AMeBOzLzRNnX7cDDwPnAM+UBsBZ4NCJ2Uzky6GzA2CRJQzBoIGTmr4Gra9QPAwtP0WYlsLJGvRuYV6N+jBIokqTm8JvKkiTAQJAkFQaCJAkwECRJhYEgSQIMBElSYSBIkgADQZJUGAiSJMBAkCQVBoIkCTAQJEmFgSBJAgwESVJhIEiSAANBklQYCJIkwECQJBUGgiQJMBAkSYWBIEkCDARJUmEgSJIAA0GSVNQdCBFxbkT8Q0T8bXl9cUQ8FxGvl+fJVdveHRG7I2JXRCyqqi+IiFfLutUREaV+XkQ8WepbIqK9gWOUJNVhKEcI3wF2Vr1eAWzKzDnApvKaiLgS6ATmAouBByPi3NLmIWA5MKc8Fpf6MuDtzJwNPACsGtZoJEnDVlcgREQb8FXgR1XlJcC6srwOuLGq/kRmvpeZe4DdwHURMR2YmJmbMzOBR/q16dvXBmBh39GDJOnsqPcI4S+B/wD8vqp2aWYeACjP00p9BvBG1XY9pTajLPevn9QmM48DR4Ep/TsREcsjojsiunt7e+vsuiSpHoMGQkT8MXAoM7fVuc9af9nnAPWB2pxcyFyTmR2Z2TF16tQ6uyNJqse4Ora5AfiTiPgjYAIwMSL+B3AwIqZn5oEyHXSobN8DzKxq3wbsL/W2GvXqNj0RMQ6YBBwZ5pgkScMw6BFCZt6dmW2Z2U7lZPHzmfkNYCPQVTbrAp4uyxuB
znLl0CwqJ4+3lmmldyLi+nJ+4NZ+bfr2tbS8x8eOECRJZ049Rwinch+wPiKWAfuAmwEyc3tErAd2AMeBOzLzRGlzO/AwcD7wTHkArAUejYjdVI4MOk+jX5KkYRhSIGTmz4Gfl+XDwMJTbLcSWFmj3g3Mq1E/RgkUSVJz+E1lSRJgIEiSCgNBkgQYCJKkwkCQJAEGgiSpMBAkSYCBIEkqTuebyqPDvZP6vT7anH5IUpN5hCBJAgwESVJhIEiSAANBklQYCJIkwECQJBUGgiQJMBAkSYWBIEkCDARJUmEgSJIAA0GSVHhzu0G0r/jph8t7JzSxI5J0hnmEIEkCDARJUmEgSJKAOgIhIiZExNaI+L8RsT0i/mOpXxwRz0XE6+V5clWbuyNid0TsiohFVfUFEfFqWbc6IqLUz4uIJ0t9S0S0n4GxSpIGUM8RwnvAv8rMq4H5wOKIuB5YAWzKzDnApvKaiLgS6ATmAouBByPi3LKvh4DlwJzyWFzqy4C3M3M28ACw6vSHJkkaikEDISveLS/Hl0cCS4B1pb4OuLEsLwGeyMz3MnMPsBu4LiKmAxMzc3NmJvBIvzZ9+9oALOw7epAknR11nUOIiHMj4mXgEPBcZm4BLs3MAwDleVrZfAbwRlXznlKbUZb7109qk5nHgaPAlBr9WB4R3RHR3dvbW9cAJUn1qSsQMvNEZs4H2qj8tT9vgM1r/WWfA9QHatO/H2sysyMzO6ZOnTpIryVJQzGkq4wy87fAz6nM/R8s00CU50Nlsx5gZlWzNmB/qbfVqJ/UJiLGAZOAI0PpmyTp9NRzldHUiLioLJ8P/GvgH4GNQFfZrAt4uixvBDrLlUOzqJw83lqmld6JiOvL+YFb+7Xp29dS4PlynkGSdJbUc+uK6cC6cqXQOcD6zPzbiNgMrI+IZcA+4GaAzNweEeuBHcBx4I7MPFH2dTvwMHA+8Ex5AKwFHo2I3VSODDobMThJUv0GDYTMfAW4pkb9MLDwFG1WAitr1LuBj51/yMxjlECRJDWH31SWJAEGgiSpMBAkSYCBIEkqDARJEmAgSJIKA0GSBBgIkqTCQJAkAQaCJKkwECRJQH03t9Nw3Dupavlo8/ohSXXyCEGSBBgIkqTCQJAkAQaCJKkwECRJwBi9yqh9xU8/XN47oYkdkaQRxCMESRJgIEiSCgNBkgQYCJKkwkCQJAEGgiSpGDQQImJmRLwQETsjYntEfKfUL46I5yLi9fI8uarN3RGxOyJ2RcSiqvqCiHi1rFsdEVHq50XEk6W+JSLaz8BYJUkDqOcI4Tjwvcy8ArgeuCMirgRWAJsycw6wqbymrOsE5gKLgQcj4tyyr4eA5cCc8lhc6suAtzNzNvAAsKoBY5MkDcGggZCZBzLzpbL8DrATmAEsAdaVzdYBN5blJcATmfleZu4BdgPXRcR0YGJmbs7MBB7p16ZvXxuAhX1HD5Kks2NI5xDKVM41wBbg0sw8AJXQAKaVzWYAb1Q16ym1GWW5f/2kNpl5HDgKTKnx/ssjojsiunt7e4fSdUnSIOq+dUVEXAD8NfDnmflPA/wBX2tFDlAfqM3Jhcw1wBqAjo6Oj61vNm+JIamV1XWEEBHjqYTBY5n5N6V8sEwDUZ4PlXoPMLOqeRuwv9TbatRPahMR44BJwJGhDkaSNHz1XGUUwFpgZ2beX7VqI9BVlruAp6vqneXKoVlUTh5vLdNK70TE9WWft/Zr07evpcDz5TyDJOksqWfK6Abgz4BXI+LlUrsHuA9YHxHLgH3AzQCZuT0i1gM7qFyhdEdmnijtbgceBs4HnikPqATOoxGxm8qRQefpDUuSNFSDBkJm/m9qz/EDLDxFm5XAyhr1bmBejfoxSqBIkprDbypLkgADQZJUGAiSJMBAkCQVBoIkCTAQJEmFgSBJAgwESVJhIEiSAANBklQYCJIkwECQJBUGgiQJMBAkSYWBIEkCDARJUmEgSJIAA0GSVBgIkiTAQJAkFQaCJAkwECRJ
hYEgSQJgXLM7MBa1r/jph8t77/tqE3siSR8Z9AghIn4cEYci4rWq2sUR8VxEvF6eJ1etuzsidkfErohYVFVfEBGvlnWrIyJK/byIeLLUt0REe4PHKEmqQz1TRg8Di/vVVgCbMnMOsKm8JiKuBDqBuaXNgxFxbmnzELAcmFMefftcBrydmbOBB4BVwx2MJGn4Bg2EzHwRONKvvARYV5bXATdW1Z/IzPcycw+wG7guIqYDEzNzc2Ym8Ei/Nn372gAs7Dt6kCSdPcM9qXxpZh4AKM/TSn0G8EbVdj2lNqMs96+f1CYzjwNHgSm13jQilkdEd0R09/b2DrPrkqRaGn2VUa2/7HOA+kBtPl7MXJOZHZnZMXXq1GF2UZJUy3AD4WCZBqI8Hyr1HmBm1XZtwP5Sb6tRP6lNRIwDJvHxKSpJ0hk23EDYCHSV5S7g6ap6Z7lyaBaVk8dby7TSOxFxfTk/cGu/Nn37Wgo8X84zjA33Tjr5IUlNMuj3ECLiceBLwCUR0QP8BXAfsD4ilgH7gJsBMnN7RKwHdgDHgTsy80TZ1e1Urlg6H3imPADWAo9GxG4qRwadDRmZJGlIBg2EzPzaKVYtPMX2K4GVNerdwLwa9WOUQJEkNY+3rpAkAQaCJKkwECRJgIEgSSoMBEkSYCBIkgoDQZIE+AM5I54/piPpbPEIQZIEGAiSpMJAkCQBBoIkqTAQJEmAgSBJKgwESRLg9xBaS/Uvqt17tHn9kDQqeYQgSQI8QhhV/FazpNPhEYIkCTAQJEmFgSBJAgwESVLhSeUxyJPPkmoxEEYrv7MgaYicMpIkASPoCCEiFgM/AM4FfpSZ9zW5S8LpJWksGRGBEBHnAv8V+ArQA/wyIjZm5o7m9mwMqJ5agmFPL1UHBwwcHoaMNDKNiEAArgN2Z+avASLiCWAJYCCMJEM5L9Ggcxj1hoeBJJ2+yMxm94GIWAoszsxvldd/BnwuM7/db7vlwPLy8nJg12m+9SXAW6e5j5HGMbWO0TguxzTyfTozp9ZaMVKOEKJG7WNJlZlrgDUNe9OI7szsaNT+RgLH1DpG47gcU2sbKVcZ9QAzq163Afub1BdJGpNGSiD8EpgTEbMi4hNAJ7CxyX2SpDFlREwZZebxiPg28CyVy05/nJnbz8JbN2z6aQRxTK1jNI7LMbWwEXFSWZLUfCNlykiS1GQGgiQJGKOBEBGLI2JXROyOiBXN7k+jRMTeiHg1Il6OiO5m92c4IuLHEXEoIl6rql0cEc9FxOvleXIz+zhUpxjTvRHxZvmsXo6IP2pmH4cqImZGxAsRsTMitkfEd0q9ZT+rAcbU0p/VUIy5cwjlNhm/ouo2GcDXRsNtMiJiL9CRmS37JZqI+APgXeCRzJxXav8JOJKZ95UAn5yZdzWzn0NxijHdC7ybmf+5mX0broiYDkzPzJci4kJgG3AjcBst+lkNMKY/pYU/q6EYi0cIH94mIzPfB/puk6ERIDNfBI70Ky8B1pXldVT+J20ZpxhTS8vMA5n5Ull+B9gJzKCFP6sBxjRmjMVAmAG8UfW6h9HzoSfw9xGxrdzmY7S4NDMPQOV/WmBak/vTKN+OiFfKlFLLTK30FxHtwDXAFkbJZ9VvTDBKPqvBjMVAqOs2GS3qhsy8Fvg3wB1lqkIj00PAvwDmAweA/9LU3gxTRFwA/DXw55n5T83uTyPUGNOo+KzqMRYDYdTeJiMz95fnQ8BTVKbHRoODZX63b573UJP7c9oy82BmnsjM3wP/jRb8rCJiPJV/OB/LzL8p5Zb+rGqNaTR8VvUai4EwKm+TERGfLCfCiIhPAn8IvDZwq5axEegqy13A003sS0P0/aNZ3ESLfVYREcBaYGdm3l+1qmU/q1ONqdU/q6EYc1cZAZTLxv6Sj26TsbK5PTp9EXEZlaMCqNyS5K9acVwR8TjwJSq3HD4I/AXwE2A98M+BfcDNmdkyJ2lPMaYvUZmCSGAv8O/65t5bQUR8AfhfwKvA70v5Hipz
7i35WQ0wpq/Rwp/VUIzJQJAkfdxYnDKSJNVgIEiSAANBklQYCJIkwECQJBUGgiQJMBAkScX/B9+t52xQqjdLAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Token-count difference between the two queries (jieba segmentation),\n",
    "# plotted per label as a grouped bar chart.\n",
    "train_df = pd.read_csv('data/train.csv')\n",
    "# Segment each query into space-joined words.\n",
    "train_df['words_1'] = train_df['query1'].map(lambda q: ' '.join(jieba.lcut(q, cut_all=False)))\n",
    "train_df['words_2'] = train_df['query2'].map(lambda q: ' '.join(jieba.lcut(q, cut_all=False)))\n",
    "# Number of tokens per query and the absolute difference.\n",
    "train_df['words_len_1'] = train_df['words_1'].str.split(' ').str.len()\n",
    "train_df['words_len_2'] = train_df['words_2'].str.split(' ').str.len()\n",
    "train_df['words_len_diff'] = (train_df['words_len_1'] - train_df['words_len_2']).abs()\n",
    "train_df = train_df.sort_values(by='words_len_diff')\n",
    "# Per-label frequency of each diff value.\n",
    "x = train_df['words_len_diff'].unique()\n",
    "count_0 = train_df[train_df['label'] == 0]['words_len_diff'].value_counts().to_dict()\n",
    "count_1 = train_df[train_df['label'] == 1]['words_len_diff'].value_counts().to_dict()\n",
    "y_0 = get_y(x, count_0)\n",
    "y_1 = get_y(x, count_1)\n",
    "# Side-by-side bars, one group per diff value.\n",
    "width = 0.3\n",
    "plt.bar(x, y_0, width=0.3, label='label_0')\n",
    "plt.bar(x + width, y_1, width=0.3, label='label_1')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42358410",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "#### 统计词差异"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "f32e67ca",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAD4CAYAAADsKpHdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAYYklEQVR4nO3df2xU553v8fcn2I3RphACJiI2rVFhW36UkGJRKio1Ke3ibVeX5IoortrGVam85ZIq7WZ1Q9I/Sq+EFK5ukyu2m0h0ieJkaRLkloLasLtJoMqtykJNbsqPsDS+C00MFjgmoUQVFNPv/WOeIYMZ7LE99hjP5yUdzZnvOc+Z55xEfHzOc+aMIgIzM7PrSt0BMzMbHRwIZmYGOBDMzCxxIJiZGeBAMDOzpKLUHRisKVOmRF1dXam7YWZ2Tdm3b9/bEVGdb9k1Gwh1dXW0tbWVuhtmZtcUSb+/2jJfMjIzM8CBYGZmSb+BIKlK0l5Jv5V0SNL3U32tpOOSXkvTF3LaPCSpXdIRScty6gslHUjLNkhSql8v6flU3yOpbhj21czM+lDIGMJ54LMR8Z6kSuBXknakZY9FxP/KXVnSHKARmAvcArwk6S8j4iLwBNAM/DvwAtAA7ABWAu9ExExJjcB64J6h756ZjWUXLlygo6ODc+fOlboro05VVRW1tbVUVlYW3KbfQIjMw47eS28r09TXA5CWA89FxHngqKR2YJGkY8CEiNgNIOlp4E4ygbAcWJvatwI/lKTwg5bMrA8dHR188IMfpK6ujnTBwYCIoLu7m46ODmbMmFFwu4LGECSNk/QacAp4MSL2pEX3Sdov6UlJk1KtBngrp3lHqtWk+d71y9pERA9wBpicpx/NktoktXV1dRXSdTMbw86dO8fkyZMdBr1IYvLkyQM+cyooECLiYkQsAGrJ/LU/j8zln48AC4BO4AfZvuTbRB/1vtr07sfGiKiPiPrq6ry30ZpZmXEY5DeY4zKgu4wi4l3gl0BDRJxMQfFn4EfAorRaBzA9p1ktcCLVa/PUL2sjqQKYCJweSN/MzGxo+h1DkFQNXIiIdyWNBz4HrJc0LSI602p3AQfT/Hbgx5IeJTOoPAvYGxEXJZ2VtBjYA9wL/ENOmyZgN7AC2OnxAzMbqLo1vyjq9o498sWibm+0K+QMYRqwS9J+4DdkxhB+DvzPdAvpfuAO4DsAEXEI2AK8DvwLsDrdYQSwCvgnoB34f2QGlAE2AZPTAPTfAWuKsXPDYu3EzGRmBtxwww19Lj927Bjz5s0b0Da/9rWv0draetXlR48e5ZOf/CSzZs3innvu4U9/+tOAtn81hdxltB+4LU/9q320WQesy1NvA644MhFxDri7v76YmRk8+OCDfOc736GxsZFvfvObbNq0iVWrVg15u/6msplZEbz33nssXbqUT3ziE3z84x9n27Ztl5b19PTQ1NTE/PnzWbFiBX/84x8B2LdvH5/5zGdYuHAhy5Yto7Oz82qbvyQi2LlzJytWrACgqamJn/3sZ0XZBweCmVkRVFVVsXXrVl599VV27drFAw88QHYo9MiRIzQ3N7N//34mTJjA448/zoULF/jWt75Fa2sr+/bt4+tf/zrf/e53+/2c7u5ubrzxRioqMhd4amtrOX78eFH24Zp92qmZ2WgSETz88MO88sorXHfddRw/fpyTJ08CMH36dJYsWQLAV77yFTZs2EBDQwMHDx7k85//PAAXL15k2rRpBX1Ob8W69daBYGZWBJs3b6arq4t9+/ZRWVlJXV3dpS+G9f4HWxIRwdy5c9m9e/eAPmfKlCm8++679PT0UFFRQUdHB7fccktR9sGBYGZjRilvEz1z5gxTp06lsrKSXbt28fvfv/+zA2+++Sa7d+/mU5/6FM8++yyf/vSn+ehHP0pXV9el+oULF/jd737H3Llz+/wcSdxxxx20trbS2NhIS0sLy5cvL8o+eAzBzKwIvvzlL9PW1kZ9
fT2bN2/mYx/72KVls2fPpqWlhfnz53P69GlWrVrFBz7wAVpbW3nwwQe59dZbWbBgAb/+9a8L+qz169fz6KOPMnPmTLq7u1m5cmVR9kHX6ve/6uvroyS/mJb9DsLaMyP/2WZ2mcOHDzN79uxSd2PUynd8JO2LiPp86/sMwczMAI8hmJmNWnfddRdHjx69rLZ+/XqWLVt2lRZD40AwMxultm7dOqKf50tGZmYGOBCuzg+xM7My40AwMzPAYwhmNpYU+6y+zG4v9xmCmdkQlOL3EH74wx8yc+ZMJPH2228PaNt9cSCYmV1jlixZwksvvcSHP/zhom7XgWBmVgQj9XsIALfddht1dXVF3wcHgplZEYzU7yEMJw8qm5kVwUj9HsJwciCYmRXBSP0ewnByIJjZ2FHC20RH6vcQhlO/YwiSqiTtlfRbSYckfT/Vb5L0oqQ30uuknDYPSWqXdETSspz6QkkH0rINSrEp6XpJz6f6Hkl1w7CvZmbDZiR/D2HDhg3U1tbS0dHB/Pnz+cY3vlGUfej39xDSP9p/ERHvSaoEfgXcD/xX4HREPCJpDTApIh6UNAd4FlgE3AK8BPxlRFyUtDe1/XfgBWBDROyQ9N+A+RHxTUmNwF0RcU9f/Rr230O42u8e+PcQzEYN/x5C34r+ewiR8V56W5mmAJYDLaneAtyZ5pcDz0XE+Yg4CrQDiyRNAyZExO7IpNDTvdpkt9UKLM2ePZiZ2cgoaAxB0jhgHzAT+MeI2CPp5ojoBIiITklT0+o1ZM4AsjpS7UKa713PtnkrbatH0hlgMnDZV/AkNQPNAB/60IcK3Uczs2vSqPw9hIi4CCyQdCOwVVJf38PO95d99FHvq03vfmwENkLmklFffTaz8hARV9zFM1YM5fcQBvPzyAP6YlpEvAv8EmgATqbLQKTXU2m1DmB6TrNa4ESq1+apX9ZGUgUwETg9kL6ZWfmpqqqiu7t7UP/4jWURQXd3N1VVVQNq1+8ZgqRq4EJEvCtpPPA5YD2wHWgCHkmv2e9pbwd+LOlRMoPKs4C9aVD5rKTFwB7gXuAfcto0AbuBFcDO8H9hM+tH9k6brq6uUndl1KmqqqK2trb/FXMUcsloGtCSxhGuA7ZExM8l7Qa2SFoJvAncDRARhyRtAV4HeoDV6ZITwCrgKWA8sCNNAJuAZyS1kzkzaBzQXphZWaqsrGTGjBml7saY0W8gRMR+4LY89W5g6VXarAPW5am3AVeMP0TEOVKgmJlZafjhdmZmBjgQzMwscSCYmRngQDAzs8SBYGZmgAPBzMwSB4KZmQEOBDMzSxwIZmYGOBDMzCxxIJiZGeBAMDOzxIFgZmaAA8HMzBIHgpmZAQ4EMzNLHAhmZgY4EMzMLHEgFNPaiZnJzOwa5EAwMzPAgWBmZkm/gSBpuqRdkg5LOiTp/lRfK+m4pNfS9IWcNg9Japd0RNKynPpCSQfSsg2SlOrXS3o+1fdIqhuGfTUzsz4UcobQAzwQEbOBxcBqSXPSssciYkGaXgBIyxqBuUAD8LikcWn9J4BmYFaaGlJ9JfBORMwEHgPWD33XzMxsIPoNhIjojIhX0/xZ4DBQ00eT5cBzEXE+Io4C7cAiSdOACRGxOyICeBq4M6dNS5pvBZZmzx7MzGxkDGgMIV3KuQ3Yk0r3Sdov6UlJk1KtBngrp1lHqtWk+d71y9pERA9wBpic5/ObJbVJauvq6hpI183MrB8FB4KkG4CfAN+OiD+QufzzEWAB0An8ILtqnubRR72vNpcXIjZGRH1E1FdXVxfadTMzK0BBgSCpkkwYbI6InwJExMmIuBgRfwZ+BCxKq3cA03Oa1wInUr02T/2yNpIqgInA6cHskJmZDU4hdxkJ2AQcjohHc+rTcla7CziY5rcDjenOoRlkBo/3RkQncFbS4rTNe4FtOW2a0vwKYGcaZzAzsxFSUcA6S4CvAgckvZZqDwNfkrSAzKWdY8DfAkTEIUlbgNfJ3KG0OiIupnargKeA
8cCONEEmcJ6R1E7mzKBxKDtlZmYD128gRMSvyH+N/4U+2qwD1uWptwHz8tTPAXf31xczMxs+/qaymZkBDgQzM0scCGZmBjgQzMwscSCYmRngQDAzs8SBYGZmgAPBzMwSB4KZmQEOBDMzSxwIZmYGOBDMzCxxIJiZGeBAMDOzxIFgZmaAA8HMzBIHgpmZAQ4EMzNLHAhmZgY4EMzMLHEgmJkZUEAgSJouaZekw5IOSbo/1W+S9KKkN9LrpJw2D0lql3RE0rKc+kJJB9KyDZKU6tdLej7V90iqG4Z9NTOzPhRyhtADPBARs4HFwGpJc4A1wMsRMQt4Ob0nLWsE5gINwOOSxqVtPQE0A7PS1JDqK4F3ImIm8Biwvgj7ZmZmA9BvIEREZ0S8mubPAoeBGmA50JJWawHuTPPLgeci4nxEHAXagUWSpgETImJ3RATwdK822W21AkuzZw9mZjYyBjSGkC7l3AbsAW6OiE7IhAYwNa1WA7yV06wj1WrSfO/6ZW0iogc4A0weSN/MzGxoCg4ESTcAPwG+HRF/6GvVPLXoo95Xm959aJbUJqmtq6urvy6bmdkAFBQIkirJhMHmiPhpKp9Ml4FIr6dSvQOYntO8FjiR6rV56pe1kVQBTARO9+5HRGyMiPqIqK+uri6k62ZmVqBC7jISsAk4HBGP5izaDjSl+SZgW069Md05NIPM4PHedFnprKTFaZv39mqT3dYKYGcaZzAzsxFSUcA6S4CvAgckvZZqDwOPAFskrQTeBO4GiIhDkrYAr5O5Q2l1RFxM7VYBTwHjgR1pgkzgPCOpncyZQePQdsvMzAaq30CIiF+R/xo/wNKrtFkHrMtTbwPm5amfIwWKmZmVhr+pbGZmgAPBzMwSB4KZmQEOBDMzSxwIZmYGOBDMzCxxIJiZGeBAMDOzxIFgZmaAA8HMzBIHgpmZAQ6EkbF2YmYyMxvFHAhmZgY4EMzMLHEgmJkZ4EAwM7PEgWBmZoADwczMEgeCmZkBDgQzM0scCGZmBhQQCJKelHRK0sGc2lpJxyW9lqYv5Cx7SFK7pCOSluXUF0o6kJZtkKRUv17S86m+R1JdkffRzMwKUMgZwlNAQ576YxGxIE0vAEiaAzQCc1ObxyWNS+s/ATQDs9KU3eZK4J2ImAk8Bqwf5L6YmdkQ9BsIEfEKcLrA7S0HnouI8xFxFGgHFkmaBkyIiN0REcDTwJ05bVrSfCuwNHv2YGZmI2coYwj3SdqfLilNSrUa4K2cdTpSrSbN965f1iYieoAzwOQh9MvMzAZhsIHwBPARYAHQCfwg1fP9ZR991PtqcwVJzZLaJLV1dXUNqMNmZta3QQVCRJyMiIsR8WfgR8CitKgDmJ6zai1wItVr89QvayOpApjIVS5RRcTGiKiPiPrq6urBdN3MzK5iUIGQxgSy7gKydyBtBxrTnUMzyAwe742ITuCspMVpfOBeYFtOm6Y0vwLYmcYZzMxsBFX0t4KkZ4HbgSmSOoDvAbdLWkDm0s4x4G8BIuKQpC3A60APsDoiLqZNrSJzx9J4YEeaADYBz0hqJ3Nm0FiE/TIzswHqNxAi4kt5ypv6WH8dsC5PvQ2Yl6d+Dri7v36Ymdnw8jeVzcwMcCCYmVniQDAzM8CBYGZmiQPBzMwAB4KZmSUOBDMzAxwIZmaWOBDMzAxwIJiZWdLvoyts8OrW/AKAY1Ul7oiZWQF8hmBmZoADwczMEgeCmZkBDgQzM0scCCVSt+YXlwadzcxGAweCmZkBDgQzM0scCGZmBjgQzMwscSCYmRngQDAzs6TfQJD0pKRTkg7m1G6S9KKkN9LrpJxlD0lql3RE0rKc+kJJB9KyDZKU6tdLej7V90iqK/I+jm5rJ2amxLejmlmpFHKG8BTQ0Ku2Bng5ImYBL6f3SJoDNAJzU5vHJY1LbZ4AmoFZacpucyXwTkTMBB4D1g92Z8zMbPD6DYSIeAU43au8HGhJ8y3AnTn15yLifEQcBdqBRZKmARMiYndE
BPB0rzbZbbUCS7NnD9cK/1VvZmPBYMcQbo6IToD0OjXVa4C3ctbrSLWaNN+7flmbiOgBzgCT832opGZJbZLaurq6Btl1MzPLp9iDyvn+so8+6n21ubIYsTEi6iOivrq6epBdNDOzfAYbCCfTZSDS66lU7wCm56xXC5xI9do89cvaSKoAJnLlJSozMxtmgw2E7UBTmm8CtuXUG9OdQzPIDB7vTZeVzkpanMYH7u3VJrutFcDONM5gZmYjqN+f0JT0LHA7MEVSB/A94BFgi6SVwJvA3QARcUjSFuB1oAdYHREX06ZWkbljaTywI00Am4BnJLWTOTNoLMqemZnZgPQbCBHxpassWnqV9dcB6/LU24B5eernSIFiZmal428qX0N8e6uZDScHgpmZAQVcMrKM7F/mx6pK3BEzs2HiMwQzMwMcCGZmljgQzMwMcCCYmVniQDAzM8CBYGZmiQPBzMwAB8KY4W8xm9lQORDMzAxwIJiZWeJAMDMzwIFgZmaJA8HMzAAHgpmZJQ4EMzMDHAhjnr+fYGaFciCMVmsnZiYzsxHiQDAzM2CIgSDpmKQDkl6T1JZqN0l6UdIb6XVSzvoPSWqXdETSspz6wrSddkkbJGko/SqY/wo3M7ukGGcId0TEgoioT+/XAC9HxCzg5fQeSXOARmAu0AA8LmlcavME0AzMSlNDEfplZmYDMByXjJYDLWm+Bbgzp/5cRJyPiKNAO7BI0jRgQkTsjogAns5pY2ZmI2SogRDAv0naJ6k51W6OiE6A9Do11WuAt3LadqRaTZrvXb+CpGZJbZLaurq6hth18x1IZparYojtl0TECUlTgRcl/Ucf6+YbF4g+6lcWIzYCGwHq6+vzrmNmZoMzpDOEiDiRXk8BW4FFwMl0GYj0eiqt3gFMz2leC5xI9do8dTMzG0GDDgRJfyHpg9l54K+Ag8B2oCmt1gRsS/PbgUZJ10uaQWbweG+6rHRW0uJ0d9G9OW3MzGyEDOWS0c3A1nSHaAXw44j4F0m/AbZIWgm8CdwNEBGHJG0BXgd6gNURcTFtaxXwFDAe2JEmMzMbQYMOhIj4T+DWPPVuYOlV2qwD1uWptwHzBtsXMzMbOn9T2czMAAeC5eHbUc3KkwPBzMwAB4KZmSUOBDMzAxwI16YSPqXV4wtmY5cDwczMAAeCmZklDgQzMwMcCGZmljgQrCg82Gx27XMgmJkZ4EAwM7PEgWDDzpeTzK4NDoSxpkRfWDOza99Qf1N5zMn+JXusqsQdMTMbYT5DsJLxpSSz0cWBYKOSw8Js5DkQzMwMcCCUjxI+IbWYfOZgNnwcCDZmOCzMhmbUBIKkBklHJLVLWlPq/pSVMXL2cDV9BYVDxOx9oyIQJI0D/hH4a2AO8CVJc0rbKxvrQdGXwYaIw8WuZaPlewiLgPaI+E8ASc8By4HXh+PD/F2DIsgGxdozhdUH22aMuPT/3CNfLHjZaGhj5UURUeo+IGkF0BAR30jvvwp8MiLu67VeM9Cc3n4UODKEj50CvD2E9mOFj4OPAfgYZJXDcfhwRFTnWzBazhCUp3ZFUkXERmBjUT5QaouI+mJs61rm4+BjAD4GWeV+HEbFGALQAUzPeV8LnChRX8zMytJoCYTfALMkzZD0AaAR2F7iPpmZlZVRcckoInok3Qf8KzAOeDIiDg3zxxbl0tMY4OPgYwA+BlllfRxGxaCymZmV3mi5ZGRmZiXmQDAzM6BMA6EcH5Mh6UlJpyQdzKndJOlFSW+k10ml7ONwkzRd0i5JhyUdknR/qpfbcaiStFfSb9Nx+H6ql9VxgMxTEiT9X0k/T+/L7hjkKrtAKOPHZDwFNPSqrQFejohZwMvp/VjWAzwQEbOBxcDq9N++3I7DeeCzEXErsABokLSY8jsOAPcDh3Pel+MxuKTsAoGcx2RExJ+A7GMyxrSIeAU43au8HGhJ8y3AnSPZp5EWEZ0R8WqaP0vmH4Iayu84RES8l95Wpikos+MgqRb4IvBPOeWy
Oga9lWMg1ABv5bzvSLVydHNEdELmH0tgaon7M2Ik1QG3AXsow+OQLpW8BpwCXoyIcjwO/xv478Cfc2rldgwuU46BUNBjMmzsknQD8BPg2xHxh1L3pxQi4mJELCDzVIBFkuaVuEsjStLfAKciYl+p+zKalGMg+DEZ7zspaRpAej1V4v4MO0mVZMJgc0T8NJXL7jhkRcS7wC/JjC+V03FYAvwXScfIXDb+rKR/pryOwRXKMRD8mIz3bQea0nwTsK2EfRl2kgRsAg5HxKM5i8rtOFRLujHNjwc+B/wHZXQcIuKhiKiNiDoy/wbsjIivUEbHIJ+y/KaypC+QuX6YfUzGutL2aPhJeha4nczjfU8C3wN+BmwBPgS8CdwdEb0HnscMSZ8G/g9wgPevGz9MZhyhnI7DfDIDpuPI/FG4JSL+h6TJlNFxyJJ0O/D3EfE35XoMssoyEMzM7ErleMnIzMzycCCYmRngQDAzs8SBYGZmgAPBzMwSB4KZmQEOBDMzS/4/3IiXTC4n6T4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Distribution of word-level differences between query pairs, split by label.\n",
    "train_df = pd.read_csv('data/train.csv')\n",
    "# Tokenize both queries with jieba (space-joined token strings)\n",
    "train_df['words_1'] = train_df['query1'].map(lambda x: ' '.join(jieba.lcut(x, cut_all=False)))\n",
    "train_df['words_2'] = train_df['query2'].map(lambda x: ' '.join(jieba.lcut(x, cut_all=False)))\n",
    "\n",
    "def count_word_diff(row):\n",
    "    # Number of words appearing in exactly one of the two queries\n",
    "    # (symmetric difference of the token sets)\n",
    "    tokens_a = set(row['words_1'].split(' '))\n",
    "    tokens_b = set(row['words_2'].split(' '))\n",
    "    return len(tokens_a ^ tokens_b)\n",
    "\n",
    "train_df['words_diff'] = train_df.apply(count_word_diff, axis=1)\n",
    "\n",
    "train_df = train_df.sort_values(by='words_diff')\n",
    "# Unique diff values and per-label frequency lookup tables\n",
    "x = train_df['words_diff'].unique()\n",
    "count_0 = train_df[train_df['label'] == 0]['words_diff'].value_counts().to_dict()\n",
    "count_1 = train_df[train_df['label'] == 1]['words_diff'].value_counts().to_dict()\n",
    "y_0 = get_y(x, count_0)\n",
    "y_1 = get_y(x, count_1)\n",
    "# Side-by-side bars per diff value, one series per label\n",
    "width = 0.3\n",
    "plt.bar(x, y_0, width=width, label='label_0')\n",
    "plt.bar(x + width, y_1, width=width, label='label_1')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8537d662",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "#### 字符差异个数（字符集合对称差）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "7dfc3511",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAD4CAYAAADsKpHdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAXkUlEQVR4nO3df4zUd53H8eerLO32bKGUHw2ynEuFVArSbdlQGszZisqemqNN6LlG7Xpi1hD0qunlCvUPuUtI5I9r74hHE5Qe2x62JasIUatWwHBGUlx6vfJL7EawncLBdmmxPUNl8X1/zGfrsAy7s7uzOzO7r0cyme+85/v5zvvb0r74fr7f+Y4iAjMzsytK3YCZmZUHB4KZmQEOBDMzSxwIZmYGOBDMzCypKnUDAzVp0qSora0tdRtmZhVl//79r0XE5HzvVWwg1NbW0tbWVuo2zMwqiqTfXe49TxmZmRngQDAzs8SBYGZmQAWfQzAzO3/+PJlMhnPnzpW6lbJTXV1NTU0NY8eOLXiMA8HMKlYmk+Haa6+ltrYWSaVup2xEBJ2dnWQyGWbMmFHwOE8ZmVnFOnfuHBMnTnQY9CCJiRMn9vvIyYFgZhXNYZDfQP65OBDMzAzwOQQzG0FqV/2wqNs7/o2PF3V75c5HCD2tGV/qDsysglxzzTW9vn/8+HHmzp3br21+7nOfo7W19bLvHzt2jNtvv51Zs2bxyU9+kj/+8Y/92v7lOBDMzCrMgw8+yFe/+lVeeuklJkyYwKZNm4qyXQeCmVkRvPXWWyxevJjbbruN97///Wzfvv2d97q6umhqamLevHksW7aMP/zhDwDs37+fD37wg8yfP58lS5Zw8uTJPj8nIti1axfLli0DoKmpie9///tF2QcHgplZEVRXV7Nt2zaef/55du/ezQMPPED3b9YfPXqU5uZmXnzxRcaNG8eGDRs4f/48X/7yl2ltbWX//v18/vOf52tf+1qfn9PZ2cl1111HVVX2FHBNTQ2vvvpqUfbBJ5XNzIogInjooYfYs2cPV1xxBa+++iqnTp0CYPr06SxatAiAz3zmM6xfv56GhgYOHjzIRz7yEQAuXLjA1KlTC/qcnop16a0DwcysCLZs2UJHRwf79+9n7Nix1NbWvvPFsJ7/w5ZERDBnzhz27t3br8+ZNGkSb7zxBl1dXVRVVZHJZHj3u99dlH3oMxAkVQN7gKvS+q0R8XVJ1wNPA7XAceBvI+L1NGY1sBy4APx9RPwk1ecDm4GrgR8B90dESLoKeByYD3QCn4yI40XZQzMbNUp5mejZs2eZMmUKY8eOZffu3fzud3/+2YGXX36ZvXv3cscdd/Dkk0/ygQ98gJtuuomOjo536ufPn+c3v/kNc+bM6fVzJHHXXXfR2tpKY2MjLS0tLF26tCj7UMg5hLeBD0XELUAd0CBpIbAK2BkRs4Cd6TWSbgYagTlAA7BB0pi0rUeBZmBWejSk+nLg9YiYCTwCrBv8rpmZDZ9Pf/rTtLW1UV9fz5YtW3jf+973znuzZ8+mpaWFefPmcebMGVasWMGVV15Ja2srDz74ILfccgt1dXX88pe/LOiz1q1bx8MPP8zMmTPp7Oxk+fLlRdmHPo8QIjth9VZ6OTY9AlgK3JnqLcDPgQdT/amIeBs4JqkdWCDpODAuIvYCSHocuBt4Jo1Zk7bVCnxTkiLfZJmZWRl5663s/x4nTZp02emfw4cP563X1dWxZ8+eS+qbN2/u9TNvvPFG9u3b179GC1DQVUaSxkh6ATgNPBsRzwE3RMRJgPQ8Ja0+DXglZ3gm1aal5Z71i8ZERBdwFpiYp49mSW2S2jo6OgraQTMzK0xBJ5Uj4gJQJ+k6YJuk3r52l+90d/RS721Mzz42AhsB6uvrh+foofuby2vODsvHmZl1u+eeezh27NhFtXXr1rFkyZIh+bx+XWUUEW9I+jnZuf9TkqZGxElJU8kePUD2b/7Tc4bVACdSvSZP
PXdMRlIVMB440899MTMbUbZt2zasn9fnlJGkyenIAElXAx8Gfg3sAJrSak1A99fydgCNkq6SNIPsyeN9aVrpTUkLlb0G674eY7q3tQzY5fMHZmbDq5AjhKlAS7pS6Apga0T8QNJeYKuk5cDLwL0AEXFI0lbgMNAFrExTTgAr+PNlp8+kB8Am4Il0AvoM2auUzMxsGBVyldGLwK156p3A4suMWQuszVNvAy45/xAR50iBYmZmpeFvKpvZyFHs29ePsotJfHM7M7NBKMXvIXzzm99k5syZSOK1117r17Z740AwM6swixYt4mc/+xnvec97irpdB4KZWREM1+8hANx6663U1tYWfR8cCGZmRTBcv4cwlHxS2cysCIbr9xCGkgPBzKwIhuv3EIaSA8HMRo4SXiY6XL+HMJR8DsHMrAiG8/cQ1q9fT01NDZlMhnnz5vGFL3yhKPugSr1lUH19fbS1tRV/w2vGX/y3DN/t1KxsHTlyhNmzZ5e6jbKV75+PpP0RUZ9vfR8hmJkZ4HMIZmZlq6x/D8HMrNxExCVX8YwUg/k9hIGcDvCUkZlVrOrqajo7Owf0P7+RLCLo7Oykurq6X+N8hGBmFav7Shv/xvqlqqurqamp6XvFHA4EM6tYY8eOZcaMGaVuY8TwlJGZmQEOBDMzSxwIZmYGOBDMzCxxIJiZGeBAMDOzxIFgZmaAA8HMzJI+A0HSdEm7JR2RdEjS/am+RtKrkl5Ij4/ljFktqV3SUUlLcurzJR1I761XugGJpKskPZ3qz0mqHYJ9NTOzXhRyhNAFPBARs4GFwEpJN6f3HomIuvT4EUB6rxGYAzQAGySNSes/CjQDs9KjIdWXA69HxEzgEWDd4HfNzMz6o89AiIiTEfF8Wn4TOAJM62XIUuCpiHg7Io4B7cACSVOBcRGxN7J3onocuDtnTEtabgUWa6TevtDMrEz16xxCmsq5FXgulb4k6UVJj0makGrTgFdyhmVSbVpa7lm/aExEdAFngYl5Pr9ZUpukNt/MysysuAoOBEnXAN8FvhIRvyc7/fNeoA44CfxL96p5hkcv9d7GXFyI2BgR9RFRP3ny5EJbNzOzAhQUCJLGkg2DLRHxPYCIOBURFyLiT8C3gAVp9QwwPWd4DXAi1Wvy1C8aI6kKGA+cGcgOmZnZwBRylZGATcCRiHg4pz41Z7V7gINpeQfQmK4cmkH25PG+iDgJvClpYdrmfcD2nDFNaXkZsCv8ixdmZsOqkN9DWAR8Fjgg6YVUewj4lKQ6slM7x4EvAkTEIUlbgcNkr1BaGREX0rgVwGbgauCZ9IBs4DwhqZ3skUHjYHbKzMz6r89AiIhfkH+O/0e9jFkLrM1TbwPm5qmfA+7tqxczMxs6/qaymZkBDgQzM0scCGZmBjgQzMwscSCYmRngQDAzs8SBYGZmgAPBzMwSB4KZmQEOBDMzSxwIZmYGOBDMzCxxIAzEmvHZh5nZCOJAMDMzwIFgZmaJA8HMzAAHgpmZJQ4EMzMDHAhmZpY4EMzMDHAgmJlZ4kAwMzPAgWBmZokDwczMgAICQdJ0SbslHZF0SNL9qX69pGclvZSeJ+SMWS2pXdJRSUty6vMlHUjvrZekVL9K0tOp/pyk2iHYVzMz60UhRwhdwAMRMRtYCKyUdDOwCtgZEbOAnek16b1GYA7QAGyQNCZt61GgGZiVHg2pvhx4PSJmAo8A64qwb2Zm1g99BkJEnIyI59Pym8ARYBqwFGhJq7UAd6flpcBTEfF2RBwD2oEFkqYC4yJib0QE8HiPMd3bagUWdx89mJnZ8OjXOYQ0lXMr8BxwQ0SchGxoAFPSatOAV3KGZVJtWlruWb9oTER0AWeBiXk+v1lSm6S2jo6O/rRuZmZ9KDgQJF0DfBf4SkT8vrdV89Sil3pvYy4uRGyMiPqIqJ88eXJfLZuZWT8UFAiSxpINgy0R8b1UPpWmgUjPp1M9A0zPGV4DnEj1mjz1i8ZIqgLGA2f6uzNmZjZwhVxl
JGATcCQiHs55awfQlJabgO059cZ05dAMsieP96VppTclLUzbvK/HmO5tLQN2pfMMZmY2TKoKWGcR8FnggKQXUu0h4BvAVknLgZeBewEi4pCkrcBhslcorYyIC2ncCmAzcDXwTHpANnCekNRO9sigcXC7ZWZm/dVnIETEL8g/xw+w+DJj1gJr89TbgLl56udIgWJmZqXhbyqbmRngQDAzs8SBYGZmgAPBzMwSB4KZmQEOBDMzSxwIZmYGOBDMzCxxIJiZGeBAMDOzxIFgZmaAA8HMzBIHQrGsGV/qDszMBsWBYGZmgAPBzMwSB4KZmQEOBDMzSxwIZmYGOBDMzCxxIJiZGeBAMDOzxIFgZmaAA8HMzBIHgpmZAQUEgqTHJJ2WdDCntkbSq5JeSI+P5by3WlK7pKOSluTU50s6kN5bL0mpfpWkp1P9OUm1Rd5HMzMrQCFHCJuBhjz1RyKiLj1+BCDpZqARmJPGbJA0Jq3/KNAMzEqP7m0uB16PiJnAI8C6Ae6LmZkNQp+BEBF7gDMFbm8p8FREvB0Rx4B2YIGkqcC4iNgbEQE8DtydM6YlLbcCi7uPHszMbPgM5hzClyS9mKaUJqTaNOCVnHUyqTYtLfesXzQmIrqAs8DEfB8oqVlSm6S2jo6OQbRuZmY9DTQQHgXeC9QBJ4F/SfV8f7OPXuq9jbm0GLExIuojon7y5Mn9atjMzHo3oECIiFMRcSEi/gR8C1iQ3soA03NWrQFOpHpNnvpFYyRVAeMpfIrKzMyKZECBkM4JdLsH6L4CaQfQmK4cmkH25PG+iDgJvClpYTo/cB+wPWdMU1peBuxK5xnMzGwYVfW1gqQngTuBSZIywNeBOyXVkZ3aOQ58ESAiDknaChwGuoCVEXEhbWoF2SuWrgaeSQ+ATcATktrJHhk0FmG/zMysn/oMhIj4VJ7ypl7WXwuszVNvA+bmqZ8D7u2rDzMzG1r+prKZmQEOBDMzSxwIZmYGOBDMzCxxIJiZGeBAMDOzxIFgZmaAA8HMzBIHgpmZAQ4EMzNLHAhmZgY4EMzMLHEgmJkZ4EAYWmvGl7oDM7OCORDMzAxwIJiZWeJAMDMzwIFgZmaJA8HMzAAHgpmZJQ4EMzMDHAhmZpY4EMzMDHAgmJlZ0mcgSHpM0mlJB3Nq10t6VtJL6XlCznurJbVLOippSU59vqQD6b31kpTqV0l6OtWfk1Rb5H00M7MCFHKEsBlo6FFbBeyMiFnAzvQaSTcDjcCcNGaDpDFpzKNAMzArPbq3uRx4PSJmAo8A6wa6M2ZmNnB9BkJE7AHO9CgvBVrScgtwd079qYh4OyKOAe3AAklTgXERsTciAni8x5jubbUCi7uPHszMbPgM9BzCDRFxEiA9T0n1acArOetlUm1aWu5Zv2hMRHQBZ4GJ+T5UUrOkNkltHR0dA2zdzMzyKfZJ5Xx/s49e6r2NubQYsTEi6iOifvLkyQNs0czM8hloIJxK00Ck59OpngGm56xXA5xI9Zo89YvGSKoCxnPpFJWZmQ2xgQbCDqApLTcB23PqjenKoRlkTx7vS9NKb0pamM4P3NdjTPe2lgG70nkGMzMbRlV9rSDpSeBOYJKkDPB14BvAVknLgZeBewEi4pCkrcBhoAtYGREX0qZWkL1i6WrgmfQA2AQ8Iamd7JFBY1H2zMzM+qXPQIiIT13mrcWXWX8tsDZPvQ2Ym6d+jhQolah21Q8BOF5d4kbMzAapz0Cw/nNImFkl8q0rzMwMcCAMm+6jBjOzcuVAMDMzwIFgZmaJA8HMzAAHgpmZJQ4EMzMDHAhmZpY4EIbTmvHZh5lZGXIglJC/m2Bm5cSBYGZmgAPBzMwSB4KZmQEOBDMzSxwIZcYnms2sVBwIZmYGOBDMzCxxIJiZGeBA6DfP8ZvZSOVAMDMzYLQHQgXcV6h21Q99VGJmw2J0B4KZmb3DgWBmZsAgA0HScUkHJL0gqS3Vrpf0rKSX0vOEnPVXS2qXdFTSkpz6/LSddknr
JWkwfZmZWf8V4wjhroioi4j69HoVsDMiZgE702sk3Qw0AnOABmCDpDFpzKNAMzArPRqK0JeZmfXDUEwZLQVa0nILcHdO/amIeDsijgHtwAJJU4FxEbE3IgJ4PGeM5eETzWY2FAYbCAH8VNJ+Sc2pdkNEnARIz1NSfRrwSs7YTKpNS8s965eQ1CypTVJbR0fHIFs3M7NcVYMcvygiTkiaAjwr6de9rJvvvED0Ur+0GLER2AhQX1+fdx0zMxuYQR0hRMSJ9Hwa2AYsAE6laSDS8+m0egaYnjO8BjiR6jV56qODf2fZzMrEgANB0rskXdu9DHwUOAjsAJrSak3A9rS8A2iUdJWkGWRPHu9L00pvSlqYri66L2dMSVXSPH0l9Wpm5WkwU0Y3ANvSFaJVwHci4seSfgVslbQceBm4FyAiDknaChwGuoCVEXEhbWsFsBm4GngmPczMbBgNOBAi4rfALXnqncDiy4xZC6zNU28D5g60F8uvdtUPOf6Nj5e6DTOrEP6mspmZAQ4EMzNLHAijjE8+m9nlOBDMzAxwIJiZWeJAGOV8XyQz6+ZAMDMzwIFgefiowWx0ciCUI9/byMxKwIFgZmaAA8EK5Ckks5HPgWAD5pAwG1kcCGZmBjgQzMwscSBY0fhyVbPK5kAwMzPAgVA5KvS7CT5qMKscDgQzMwMG95vKI0r332KPV5e4ETOzEvERgg07TyGZlScHQiWr0PMK+TgkzErPgWBlyyFhNrwcCCPJmvEj6qihJ1+xZDa0HAhW0RwSZsVTNoEgqUHSUUntklaVup8RY4QfNeSTLyAKqTlcbLQri0CQNAb4d+CvgZuBT0m6eag+b9T/hz/KAmIw8v1ZGWi4jPo/d1b2yiIQgAVAe0T8NiL+CDwFLC1xT6NLvpDoWct3tHG5WiHbt6IezQw0vIq9fatciohS94CkZUBDRHwhvf4scHtEfKnHes1Ac3p5E3B0EB87CXhtEONLzf2XViX3X8m9g/sfrPdExOR8b5TLN5WVp3ZJUkXERmBjUT5QaouI+mJsqxTcf2lVcv+V3Du4/6FULlNGGWB6zusa4ESJejEzG5XKJRB+BcySNEPSlUAjsKPEPZmZjSplMWUUEV2SvgT8BBgDPBYRh4b4Y4sy9VRC7r+0Krn/Su4d3P+QKYuTymZmVnrlMmVkZmYl5kAwMzNgFAZCJd4iQ9Jjkk5LOphTu17Ss5JeSs8TStnj5UiaLmm3pCOSDkm6P9Urpf9qSfsk/U/q/59SvSL6h+ydACT9t6QfpNcV0zuApOOSDkh6QVJbqlXMPki6TlKrpF+n/w7uKNf+R1UgDPctMopoM9DQo7YK2BkRs4Cd6XU56gIeiIjZwEJgZfpnXin9vw18KCJuAeqABkkLqZz+Ae4HjuS8rqTeu90VEXU51+9X0j78G/DjiHgfcAvZfxfl2X9EjJoHcAfwk5zXq4HVpe6rwN5rgYM5r48CU9PyVOBoqXsscD+2Ax+pxP6BvwCeB26vlP7JfqdnJ/Ah4AeV+GcHOA5M6lGriH0AxgHHSBfwlHv/o+oIAZgGvJLzOpNqleiGiDgJkJ6nlLifPkmqBW4FnqOC+k9TLi8Ap4FnI6KS+v9X4B+BP+XUKqX3bgH8VNL+dPsaqJx9uBHoAP4jTdt9W9K7KNP+R1sgFHSLDCs+SdcA3wW+EhG/L3U//RERFyKijuzfthdImlvilgoi6RPA6YjYX+peBmlRRNxGdqp3paS/KnVD/VAF3AY8GhG3Av9HuUwP5THaAmEk3SLjlKSpAOn5dIn7uSxJY8mGwZaI+F4qV0z/3SLiDeDnZM/nVEL/i4C/kXSc7B2EPyTpP6mM3t8RESfS82lgG9m7I1fKPmSATDqqBGglGxBl2f9oC4SRdIuMHUBTWm4iOzdfdiQJ2AQciYiHc96qlP4nS7ouLV8NfBj4NRXQf0SsjoiaiKgl+2d9V0R8hgrovZukd0m6tnsZ+ChwkArZh4j4X+AVSTel0mLg
MGXa/6j7prKkj5GdV+2+Rcba0nbUN0lPAneSvW3uKeDrwPeBrcBfAi8D90bEmRK1eFmSPgD8F3CAP89jP0T2PEIl9D8PaCH75+UKYGtE/LOkiVRA/90k3Qn8Q0R8opJ6l3Qj2aMCyE6/fCci1lbYPtQB3wauBH4L/B3pzxJl1v+oCwQzM8tvtE0ZmZnZZTgQzMwMcCCYmVniQDAzM8CBYGZmiQPBzMwAB4KZmSX/D47/LVvnkRbAAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Distribution of character-level differences between query pairs, split by label.\n",
    "train_df = pd.read_csv('data/train.csv')\n",
    "\n",
    "# Number of characters appearing in exactly one of the two queries\n",
    "# (symmetric difference of character sets — not a common-substring length)\n",
    "train_df['chars_diff'] = train_df.apply(\n",
    "    lambda row: len(set(row['query1']) ^ set(row['query2'])), axis=1)\n",
    "\n",
    "train_df = train_df.sort_values(by='chars_diff')\n",
    "# Unique diff values and per-label frequency lookup tables\n",
    "x = train_df['chars_diff'].unique()\n",
    "count_0 = train_df[train_df['label'] == 0]['chars_diff'].value_counts().to_dict()\n",
    "count_1 = train_df[train_df['label'] == 1]['chars_diff'].value_counts().to_dict()\n",
    "y_0 = get_y(x, count_0)\n",
    "y_1 = get_y(x, count_1)\n",
    "# Side-by-side bars per diff value, one series per label\n",
    "width = 0.3\n",
    "plt.bar(x, y_0, width=width, label='label_0')\n",
    "plt.bar(x + width, y_1, width=width, label='label_1')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ddf0f6d",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "#### TFIDF编码相似度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "f73e3307",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "不相似的距离 0.6841645162992568\n",
      "相似的距离 0.7919065411815118\n"
     ]
    }
   ],
   "source": [
    "# TF-IDF cosine similarity between query pairs (evaluated on the validation split).\n",
    "train_df = pd.read_csv('data/val.csv')\n",
    "# TfidfVectorizer already applies IDF weighting and l2 normalisation.\n",
    "# The previous TfidfTransformer pass re-applied IDF to an already\n",
    "# tf-idf-weighted matrix (it expects raw counts), distorting the scores.\n",
    "vectorizer = TfidfVectorizer()\n",
    "# Tokenize both queries with jieba (space-joined for the vectorizer)\n",
    "train_df['words_1'] = train_df['query1'].map(lambda x: ' '.join(jieba.lcut(x, cut_all=False)))\n",
    "train_df['words_2'] = train_df['query2'].map(lambda x: ' '.join(jieba.lcut(x, cut_all=False)))\n",
    "\n",
    "# Build the corpus with query1/query2 interleaved (even/odd positions)\n",
    "seq_list = []\n",
    "for idx, row in train_df.iterrows():\n",
    "    seq_list.append(row['words_1'])\n",
    "    seq_list.append(row['words_2'])\n",
    "\n",
    "tfidf_list = vectorizer.fit_transform(seq_list).toarray()\n",
    "# Split back into per-query feature rows (even index -> query1, odd -> query2)\n",
    "train_df['tf_idf_1'] = list(tfidf_list[0::2])\n",
    "train_df['tf_idf_2'] = list(tfidf_list[1::2])\n",
    "# Cosine similarity = 1 - cosine distance\n",
    "train_df['tf_idf_diff'] = train_df.apply(lambda row: 1 - spatial.distance.cosine(row['tf_idf_1'], row['tf_idf_2']), axis=1)\n",
    "\n",
    "print('不相似的距离', np.mean(train_df[train_df['label'] == 0]['tf_idf_diff'].tolist()))\n",
    "print('相似的距离', np.mean(train_df[train_df['label'] == 1]['tf_idf_diff'].tolist()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "832e13ff",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "#### 基于BERT的embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "95637613",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[2023-01-12 13:59:03,653] [    INFO]\u001b[0m - We use pattern recognition to recognize the Tokenizer class.\u001b[0m\n",
      "\u001b[32m[2023-01-12 13:59:03,653] [    INFO]\u001b[0m - We are using <class 'paddlenlp.transformers.bert.tokenizer.BertTokenizer'> to load 'D:/env/bert_model/hfl/chinese-bert-wwm-ext'.\u001b[0m\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Failed to import transformers.models.bert.modeling_bert because of the following error (look up to see its traceback):\n[WinError 127] 找不到指定的程序。 Error loading \"D:\\env\\Anaconda3\\lib\\site-packages\\torch\\lib\\cublas64_11.dll\" or one of its dependencies.",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mOSError\u001b[0m                                   Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\transformers\\file_utils.py\u001b[0m in \u001b[0;36m_get_module\u001b[1;34m(self, module_name)\u001b[0m\n\u001b[0;32m   2776\u001b[0m         \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2777\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mimport_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\".\"\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mmodule_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   2778\u001b[0m         \u001b[1;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\importlib\\__init__.py\u001b[0m in \u001b[0;36mimport_module\u001b[1;34m(name, package)\u001b[0m\n\u001b[0;32m    126\u001b[0m             \u001b[0mlevel\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 127\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0m_bootstrap\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_gcd_import\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mlevel\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpackage\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    128\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\importlib\\_bootstrap.py\u001b[0m in \u001b[0;36m_gcd_import\u001b[1;34m(name, package, level)\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\importlib\\_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load\u001b[1;34m(name, import_)\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\importlib\\_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load_unlocked\u001b[1;34m(name, import_)\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\importlib\\_bootstrap.py\u001b[0m in \u001b[0;36m_load_unlocked\u001b[1;34m(spec)\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\importlib\\_bootstrap_external.py\u001b[0m in \u001b[0;36mexec_module\u001b[1;34m(self, module)\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\importlib\\_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[1;34m(f, *args, **kwds)\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\transformers\\models\\bert\\modeling_bert.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     24\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 25\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     26\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcheckpoint\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torch\\__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m    123\u001b[0m                 \u001b[0merr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstrerror\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;34mf' Error loading \"{dll}\" or one of its dependencies.'\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 124\u001b[1;33m                 \u001b[1;32mraise\u001b[0m \u001b[0merr\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    125\u001b[0m             \u001b[1;32melif\u001b[0m \u001b[0mres\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mOSError\u001b[0m: [WinError 127] 找不到指定的程序。 Error loading \"D:\\env\\Anaconda3\\lib\\site-packages\\torch\\lib\\cublas64_11.dll\" or one of its dependencies.",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-85-b9264965d06f>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     15\u001b[0m \u001b[1;31m# 加载bert\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     16\u001b[0m \u001b[0mtokenizer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mAutoTokenizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'D:/env/bert_model/hfl/chinese-bert-wwm-ext'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 17\u001b[1;33m \u001b[0membedding\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mAutoModel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'D:/env/bert_model/hfl/chinese-bert-wwm-ext'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0membeddings\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     18\u001b[0m \u001b[1;31m# 获取token\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     19\u001b[0m \u001b[0membed_1\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\transformers\\models\\auto\\auto_factory.py\u001b[0m in \u001b[0;36mfrom_pretrained\u001b[1;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[0;32m    444\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mmodel_class\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpretrained_model_name_or_path\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0mmodel_args\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    445\u001b[0m         \u001b[1;32melif\u001b[0m \u001b[0mtype\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mcls\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_model_mapping\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 446\u001b[1;33m             \u001b[0mmodel_class\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_get_model_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcls\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_model_mapping\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    447\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mmodel_class\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpretrained_model_name_or_path\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0mmodel_args\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    448\u001b[0m         raise ValueError(\n",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\transformers\\models\\auto\\auto_factory.py\u001b[0m in \u001b[0;36m_get_model_class\u001b[1;34m(config, model_mapping)\u001b[0m\n\u001b[0;32m    358\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    359\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_get_model_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel_mapping\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 360\u001b[1;33m     \u001b[0msupported_models\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel_mapping\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mtype\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    361\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msupported_models\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    362\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0msupported_models\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\transformers\\models\\auto\\auto_factory.py\u001b[0m in \u001b[0;36m__getitem__\u001b[1;34m(self, key)\u001b[0m\n\u001b[0;32m    565\u001b[0m             \u001b[1;32mraise\u001b[0m \u001b[0mKeyError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    566\u001b[0m         \u001b[0mmodel_name\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_model_mapping\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mmodel_type\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 567\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_load_attr_from_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel_type\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel_name\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    568\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    569\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m_load_attr_from_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel_type\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mattr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\transformers\\models\\auto\\auto_factory.py\u001b[0m in \u001b[0;36m_load_attr_from_module\u001b[1;34m(self, model_type, attr)\u001b[0m\n\u001b[0;32m    571\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mmodule_name\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_modules\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    572\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_modules\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mmodule_name\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mimport_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf\".{module_name}\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"transformers.models\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 573\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mgetattribute_from_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_modules\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mmodule_name\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mattr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    574\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    575\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mkeys\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\transformers\\models\\auto\\auto_factory.py\u001b[0m in \u001b[0;36mgetattribute_from_module\u001b[1;34m(module, attr)\u001b[0m\n\u001b[0;32m    533\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mattr\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    534\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgetattribute_from_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0ma\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mattr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 535\u001b[1;33m     \u001b[1;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mattr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    536\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mattr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    537\u001b[0m     \u001b[1;31m# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\transformers\\file_utils.py\u001b[0m in \u001b[0;36m__getattr__\u001b[1;34m(self, name)\u001b[0m\n\u001b[0;32m   2765\u001b[0m             \u001b[0mvalue\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2766\u001b[0m         \u001b[1;32melif\u001b[0m \u001b[0mname\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_class_to_module\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2767\u001b[1;33m             \u001b[0mmodule\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_class_to_module\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   2768\u001b[0m             \u001b[0mvalue\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2769\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\transformers\\file_utils.py\u001b[0m in \u001b[0;36m_get_module\u001b[1;34m(self, module_name)\u001b[0m\n\u001b[0;32m   2777\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mimport_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\".\"\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mmodule_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2778\u001b[0m         \u001b[1;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2779\u001b[1;33m             raise RuntimeError(\n\u001b[0m\u001b[0;32m   2780\u001b[0m                 \u001b[1;34mf\"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\\n{e}\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2781\u001b[0m             ) from e\n",
      "\u001b[1;31mRuntimeError\u001b[0m: Failed to import transformers.models.bert.modeling_bert because of the following error (look up to see its traceback):\n[WinError 127] 找不到指定的程序。 Error loading \"D:\\env\\Anaconda3\\lib\\site-packages\\torch\\lib\\cublas64_11.dll\" or one of its dependencies."
     ]
    }
   ],
   "source": [
    "def get_mean_embed(seq):\n",
    "    '''\n",
    "    Mean-pooled BERT embedding for one text sequence.\n",
    "\n",
    "    seq: raw text string.\n",
    "    Returns a (1, hidden_size) numpy array averaged over the token axis.\n",
    "    '''\n",
    "    input_ids = torch.tensor([tokenizer.encode_plus(seq)['input_ids']], dtype=torch.long)\n",
    "    # inference only: no_grad avoids building an autograd graph\n",
    "    with torch.no_grad():\n",
    "        embed = embedding(input_ids=input_ids).numpy()\n",
    "    # average over the sequence (token) axis -> one vector per sentence\n",
    "    return np.mean(embed, axis=1)\n",
    "\n",
    "train_df = pd.read_csv('data/train.csv')\n",
    "# load the BERT tokenizer and the embedding layer only (no transformer blocks)\n",
    "tokenizer = AutoTokenizer.from_pretrained('D:/env/bert_model/hfl/chinese-bert-wwm-ext')\n",
    "embedding = AutoModel.from_pretrained('D:/env/bert_model/hfl/chinese-bert-wwm-ext').embeddings\n",
    "# embed every query pair and record the cosine similarity\n",
    "embed_1 = []\n",
    "embed_2 = []\n",
    "cos_sim = []\n",
    "for idx, row in tqdm(train_df.iterrows(), total=len(train_df)):\n",
    "    bert_embed_1 = get_mean_embed(row['query1'])\n",
    "    bert_embed_2 = get_mean_embed(row['query2'])\n",
    "\n",
    "    embed_1.append(bert_embed_1)\n",
    "    embed_2.append(bert_embed_2)\n",
    "    # cosine similarity = 1 - cosine distance\n",
    "    cos_sim.append(1 - spatial.distance.cosine(bert_embed_1, bert_embed_2))\n",
    "\n",
    "train_df['embed_1'] = embed_1\n",
    "train_df['embed_2'] = embed_2\n",
    "train_df['cos_sim'] = cos_sim\n",
    "print('不相似的距离', np.mean(train_df[train_df['label'] == 0]['cos_sim'].tolist()))\n",
    "print('相似的距离', np.mean(train_df[train_df['label'] == 1]['cos_sim'].tolist()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0dfe9e9",
   "metadata": {},
   "source": [
    "## 任务4：文本相似度（词向量与句子编码）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8de90804",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12500it [00:00, 22953.83it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(126597, 165612)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# jieba segmentation (full mode)\n",
    "def cut(content):\n",
    "    '''Segment `content` with jieba full mode; print the offending value and re-raise on non-string input.'''\n",
    "    seg_list = []\n",
    "    try:\n",
    "        seg_list = jieba.lcut(content, cut_all=True)\n",
    "    except AttributeError as ex:\n",
    "        # content is likely not a string (e.g. NaN) -- show it before failing\n",
    "        print(content)\n",
    "        raise ex\n",
    "    return seg_list\n",
    "\n",
    "train_df = pd.read_csv('data/train.csv')\n",
    "# segment both queries\n",
    "train_df['words_a'] = train_df['query1'].apply(cut)\n",
    "train_df['words_b'] = train_df['query2'].apply(cut)\n",
    "# build the word2vec training corpus: one sentence per query\n",
    "context = []\n",
    "for idx, row in tqdm(train_df.iterrows(), total=len(train_df)):\n",
    "    context.append(row['words_a'])\n",
    "    context.append(row['words_b'])\n",
    "\n",
    "wv_model = Word2Vec(sentences=context, vector_size=100, window=5, min_count=1, workers=4)\n",
    "# one extra training epoch; total_examples must be the corpus size (not 1),\n",
    "# otherwise gensim mis-computes the learning-rate decay schedule\n",
    "wv_model.train(context, total_examples=len(context), epochs=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cc517067",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12500it [00:00, 22350.16it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████| 12462/12462 [00:00<00:00, 1556517.57it/s]\n"
     ]
    }
   ],
   "source": [
    "# Document frequency over the whole corpus: each query is one document and\n",
    "# a word is counted at most once per document (hence the set()).\n",
    "count_list = []\n",
    "words_num = 0\n",
    "for idx, row in tqdm(train_df.iterrows()):\n",
    "    for words in (row['words_a'], row['words_b']):\n",
    "        count_list.extend(set(words))\n",
    "        words_num += 1\n",
    "\n",
    "count = Counter(count_list)\n",
    "# inverse document frequency: log(N / (df + 1))\n",
    "idf = {k: math.log(words_num / (v + 1)) for k, v in tqdm(count.items())}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "b1be9e17",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12500it [00:01, 10549.48it/s]\n",
      "12500it [00:01, 9032.99it/s]\n",
      "12500it [00:02, 6212.43it/s]\n",
      "12500it [00:01, 6533.95it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>query1</th>\n",
       "      <th>query2</th>\n",
       "      <th>label</th>\n",
       "      <th>words_a</th>\n",
       "      <th>words_b</th>\n",
       "      <th>sif_wv_a</th>\n",
       "      <th>sif_wv_b</th>\n",
       "      <th>max_wv_a</th>\n",
       "      <th>max_wv_b</th>\n",
       "      <th>mean_wv_a</th>\n",
       "      <th>mean_wv_b</th>\n",
       "      <th>idf_wv_a</th>\n",
       "      <th>idf_wv_b</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>谁有狂三这张高清的</td>\n",
       "      <td>这张高清图，谁有</td>\n",
       "      <td>0</td>\n",
       "      <td>[谁, 有, 狂, 三, 这, 张, 高清, 的]</td>\n",
       "      <td>[这, 张, 高清, 图, ，, 谁, 有]</td>\n",
       "      <td>[0.00075101666, -0.0009414684, -0.0027995398, ...</td>\n",
       "      <td>[-0.00057386793, -0.0015908489, -0.0026177235,...</td>\n",
       "      <td>[-0.0077954414, 0.9239516, 0.44552782, 0.33902...</td>\n",
       "      <td>[-0.17484796, 0.8865976, 0.43001956, 0.3390268...</td>\n",
       "      <td>[-0.3810672, 0.56649673, 0.21015589, 0.1650711...</td>\n",
       "      <td>[-0.49064, 0.6627479, 0.20352013, 0.19425254, ...</td>\n",
       "      <td>[-1.3579595, 1.999914, 0.69180375, 0.5824405, ...</td>\n",
       "      <td>[-1.8708805, 2.5265408, 0.7487483, 0.7424466, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>英雄联盟什么英雄最好</td>\n",
       "      <td>英雄联盟最好英雄是什么</td>\n",
       "      <td>1</td>\n",
       "      <td>[英雄, 联盟, 什么, 英雄, 最好]</td>\n",
       "      <td>[英雄, 联盟, 最好, 英雄, 是, 什么]</td>\n",
       "      <td>[0.0010158978, 0.00027272105, -4.0359795e-05, ...</td>\n",
       "      <td>[0.0007458804, 0.00023201294, -0.00016308203, ...</td>\n",
       "      <td>[-0.08105441, 0.9698236, 0.83412725, 1.0974805...</td>\n",
       "      <td>[-0.08105441, 0.9698236, 0.83412725, 1.0974805...</td>\n",
       "      <td>[-0.24039476, 0.63472, 0.37772053, 0.34956783,...</td>\n",
       "      <td>[-0.39853027, 0.69034606, 0.25796387, 0.324222...</td>\n",
       "      <td>[-1.1932749, 2.4927368, 1.257377, 0.8515301, 0...</td>\n",
       "      <td>[-1.3220998, 2.344159, 0.95389634, 0.7640311, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>这是什么意思，被蹭网吗</td>\n",
       "      <td>我也是醉了，这是什么意思</td>\n",
       "      <td>0</td>\n",
       "      <td>[这, 是, 什么, 意思, ，, 被, 蹭, 网, 吗]</td>\n",
       "      <td>[我, 也, 是, 醉, 了, ，, 这, 是, 什么, 意思]</td>\n",
       "      <td>[-0.0014983066, -0.0003533112, -0.0009976879, ...</td>\n",
       "      <td>[-0.0011599427, -7.539801e-05, -0.0016003116, ...</td>\n",
       "      <td>[-0.024600787, 0.9698236, 0.83412725, 1.097480...</td>\n",
       "      <td>[-0.048710022, 0.9698236, 0.83412725, 1.097480...</td>\n",
       "      <td>[-0.5681006, 0.7457535, 0.23108628, 0.35262153...</td>\n",
       "      <td>[-0.6828134, 0.7691415, 0.10898125, 0.21714213...</td>\n",
       "      <td>[-1.7268541, 2.2139695, 0.6960715, 0.8554218, ...</td>\n",
       "      <td>[-1.8062147, 2.0117328, 0.30432898, 0.4488391,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>现在有什么动画片好看呢？</td>\n",
       "      <td>现在有什么好看的动画片吗？</td>\n",
       "      <td>1</td>\n",
       "      <td>[现在, 有, 什么, 动画, 动画片, 画片, 好看, 呢, ？]</td>\n",
       "      <td>[现在, 有, 什么, 好看, 的, 动画, 动画片, 画片, 吗, ？]</td>\n",
       "      <td>[-0.0006638123, -0.00333717, -0.0012107268, 0....</td>\n",
       "      <td>[-0.00037023518, -0.0029268796, -0.00092696166...</td>\n",
       "      <td>[-0.08105441, 0.9698236, 0.83412725, 1.0974805...</td>\n",
       "      <td>[-0.08105441, 0.9698236, 0.83412725, 1.0974805...</td>\n",
       "      <td>[-0.373247, 0.5750615, 0.2937528, 0.30965212, ...</td>\n",
       "      <td>[-0.37998262, 0.62548745, 0.34081635, 0.327185...</td>\n",
       "      <td>[-1.3610084, 1.8516116, 0.9059567, 0.8284683, ...</td>\n",
       "      <td>[-1.1786901, 1.7143263, 0.9125215, 0.78298885,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>请问晶达电子厂现在的工资待遇怎么样要求有哪些</td>\n",
       "      <td>三星电子厂工资待遇怎么样啊</td>\n",
       "      <td>0</td>\n",
       "      <td>[请问, 晶, 达, 电子, 电子厂, 现在, 的, 工资, 工资待遇, 待遇, 怎么, 怎...</td>\n",
       "      <td>[三星, 三星电子, 电子, 电子厂, 工资, 工资待遇, 待遇, 怎么, 怎么样, 啊]</td>\n",
       "      <td>[0.0003437847, 0.00028534047, -0.00011031516, ...</td>\n",
       "      <td>[-0.00035536848, 0.0021531135, 0.0006148936, 0...</td>\n",
       "      <td>[0.053401962, 1.2476405, 0.66729486, 0.4752677...</td>\n",
       "      <td>[-0.014251087, 1.2476405, 0.61681724, 0.475267...</td>\n",
       "      <td>[-0.24440448, 0.44179374, 0.21480398, 0.165082...</td>\n",
       "      <td>[-0.2544273, 0.40147883, 0.15589032, 0.1642858...</td>\n",
       "      <td>[-0.8873614, 1.5942813, 0.7564742, 0.60922605,...</td>\n",
       "      <td>[-0.9679738, 1.5586826, 0.62105405, 0.61899275...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12495</th>\n",
       "      <td>微店怎么开？怎么做代理？</td>\n",
       "      <td>微店怎样代理</td>\n",
       "      <td>1</td>\n",
       "      <td>[微, 店, 怎么, 开, ？, 怎么, 做, 代理, ？]</td>\n",
       "      <td>[微, 店, 怎样, 代理]</td>\n",
       "      <td>[-0.001035525, 9.592343e-05, 0.0005946201, 0.0...</td>\n",
       "      <td>[-0.00091870874, -5.8695674e-05, 0.0021628449,...</td>\n",
       "      <td>[-0.11761739, 1.2476405, 0.7010263, 0.47526774...</td>\n",
       "      <td>[-0.11761739, 1.1874398, 0.7010263, 0.36259052...</td>\n",
       "      <td>[-0.5269485, 0.77384484, 0.33383444, 0.297116,...</td>\n",
       "      <td>[-0.3117711, 0.61655945, 0.4177478, 0.20448962...</td>\n",
       "      <td>[-1.4481721, 2.1688704, 1.0642663, 0.7870712, ...</td>\n",
       "      <td>[-1.4010504, 2.6534526, 1.7952367, 0.9098671, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12496</th>\n",
       "      <td>小学科学三年级上</td>\n",
       "      <td>小学三年级科学</td>\n",
       "      <td>0</td>\n",
       "      <td>[小学, 学科, 科学, 三年, 三年级, 年级, 上]</td>\n",
       "      <td>[小学, 三年, 三年级, 年级, 科学]</td>\n",
       "      <td>[9.004027e-06, -0.00142885, 0.00081230234, -0....</td>\n",
       "      <td>[0.00040889718, -0.0011433773, 0.0020952262, -...</td>\n",
       "      <td>[-0.0062479675, 0.96860826, 0.28056538, 0.1461...</td>\n",
       "      <td>[-0.043239728, 0.41075003, 0.22162169, 0.14610...</td>\n",
       "      <td>[-0.20306472, 0.27727482, 0.11639612, -0.02241...</td>\n",
       "      <td>[-0.1049123, 0.19250993, 0.10605173, 0.0654489...</td>\n",
       "      <td>[-1.0425178, 1.5224932, 0.6920391, 0.042029757...</td>\n",
       "      <td>[-0.7344848, 1.3373768, 0.73671997, 0.44910973...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12497</th>\n",
       "      <td>冬眠是什么意思？</td>\n",
       "      <td>冬眠的意思是什么</td>\n",
       "      <td>1</td>\n",
       "      <td>[冬眠, 是, 什么, 意思, ？]</td>\n",
       "      <td>[冬眠, 的, 意思, 是, 什么]</td>\n",
       "      <td>[-0.0007360069, -0.0012951059, -0.00074847107,...</td>\n",
       "      <td>[-0.00067913055, -0.0012926087, -0.0006993339,...</td>\n",
       "      <td>[-0.016677497, 0.9698236, 0.83412725, 1.097480...</td>\n",
       "      <td>[-0.016677497, 0.9698236, 0.83412725, 1.097480...</td>\n",
       "      <td>[-0.61691874, 0.7294527, 0.13682225, 0.4141247...</td>\n",
       "      <td>[-0.55178684, 0.7358588, 0.19610454, 0.3864874...</td>\n",
       "      <td>[-1.2316244, 1.232954, 0.115233585, 0.59324133...</td>\n",
       "      <td>[-1.1142532, 1.1985903, 0.17986894, 0.5401553,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12498</th>\n",
       "      <td>天猫有假货吗</td>\n",
       "      <td>天猫卖假货吗</td>\n",
       "      <td>0</td>\n",
       "      <td>[天, 猫, 有假, 假货, 吗]</td>\n",
       "      <td>[天, 猫, 卖假, 假货, 吗]</td>\n",
       "      <td>[-0.001712067, -0.0018830914, -0.0011562603, 0...</td>\n",
       "      <td>[-0.0012003537, -0.0030034818, -0.0010703667, ...</td>\n",
       "      <td>[-0.01671742, 0.9001708, 0.6354656, 0.58082336...</td>\n",
       "      <td>[-0.011306549, 0.9001708, 0.6354656, 0.5808233...</td>\n",
       "      <td>[-0.22814763, 0.3581218, 0.21679118, 0.1769760...</td>\n",
       "      <td>[-0.22697529, 0.35625142, 0.21671212, 0.176217...</td>\n",
       "      <td>[-0.98722535, 1.5471647, 0.8786985, 0.6976306,...</td>\n",
       "      <td>[-0.9775583, 1.5306193, 0.8782679, 0.6917552, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12499</th>\n",
       "      <td>天兵天将是什么生肖？</td>\n",
       "      <td>天兵天将是指什么生肖</td>\n",
       "      <td>1</td>\n",
       "      <td>[天兵, 天兵天将, 天将, 是, 什么, 生肖, ？]</td>\n",
       "      <td>[天兵, 天兵天将, 天将, 是, 指, 什么, 生肖]</td>\n",
       "      <td>[-0.0028270306, -0.0009945175, -0.0019709868, ...</td>\n",
       "      <td>[-0.0040764976, -0.0011325236, -0.0029783533, ...</td>\n",
       "      <td>[-0.025301421, 0.9698236, 0.83412725, 1.097480...</td>\n",
       "      <td>[-0.025301421, 0.9698236, 0.83412725, 1.097480...</td>\n",
       "      <td>[-0.42586666, 0.5373887, 0.12877971, 0.3031228...</td>\n",
       "      <td>[-0.36814597, 0.474999, 0.1229522, 0.2690135, ...</td>\n",
       "      <td>[-1.0893434, 1.19944, 0.24725376, 0.5682123, -...</td>\n",
       "      <td>[-1.2632762, 1.3922013, 0.30616006, 0.62948805...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>12500 rows × 13 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                       query1         query2  label  \\\n",
       "0                   谁有狂三这张高清的       这张高清图，谁有      0   \n",
       "1                  英雄联盟什么英雄最好    英雄联盟最好英雄是什么      1   \n",
       "2                 这是什么意思，被蹭网吗   我也是醉了，这是什么意思      0   \n",
       "3                现在有什么动画片好看呢？  现在有什么好看的动画片吗？      1   \n",
       "4      请问晶达电子厂现在的工资待遇怎么样要求有哪些  三星电子厂工资待遇怎么样啊      0   \n",
       "...                       ...            ...    ...   \n",
       "12495            微店怎么开？怎么做代理？         微店怎样代理      1   \n",
       "12496                小学科学三年级上        小学三年级科学      0   \n",
       "12497                冬眠是什么意思？       冬眠的意思是什么      1   \n",
       "12498                  天猫有假货吗         天猫卖假货吗      0   \n",
       "12499              天兵天将是什么生肖？     天兵天将是指什么生肖      1   \n",
       "\n",
       "                                                 words_a  \\\n",
       "0                              [谁, 有, 狂, 三, 这, 张, 高清, 的]   \n",
       "1                                   [英雄, 联盟, 什么, 英雄, 最好]   \n",
       "2                          [这, 是, 什么, 意思, ，, 被, 蹭, 网, 吗]   \n",
       "3                     [现在, 有, 什么, 动画, 动画片, 画片, 好看, 呢, ？]   \n",
       "4      [请问, 晶, 达, 电子, 电子厂, 现在, 的, 工资, 工资待遇, 待遇, 怎么, 怎...   \n",
       "...                                                  ...   \n",
       "12495                     [微, 店, 怎么, 开, ？, 怎么, 做, 代理, ？]   \n",
       "12496                       [小学, 学科, 科学, 三年, 三年级, 年级, 上]   \n",
       "12497                                 [冬眠, 是, 什么, 意思, ？]   \n",
       "12498                                  [天, 猫, 有假, 假货, 吗]   \n",
       "12499                       [天兵, 天兵天将, 天将, 是, 什么, 生肖, ？]   \n",
       "\n",
       "                                             words_b  \\\n",
       "0                             [这, 张, 高清, 图, ，, 谁, 有]   \n",
       "1                            [英雄, 联盟, 最好, 英雄, 是, 什么]   \n",
       "2                   [我, 也, 是, 醉, 了, ，, 这, 是, 什么, 意思]   \n",
       "3              [现在, 有, 什么, 好看, 的, 动画, 动画片, 画片, 吗, ？]   \n",
       "4      [三星, 三星电子, 电子, 电子厂, 工资, 工资待遇, 待遇, 怎么, 怎么样, 啊]   \n",
       "...                                              ...   \n",
       "12495                                 [微, 店, 怎样, 代理]   \n",
       "12496                          [小学, 三年, 三年级, 年级, 科学]   \n",
       "12497                             [冬眠, 的, 意思, 是, 什么]   \n",
       "12498                              [天, 猫, 卖假, 假货, 吗]   \n",
       "12499                   [天兵, 天兵天将, 天将, 是, 指, 什么, 生肖]   \n",
       "\n",
       "                                                sif_wv_a  \\\n",
       "0      [0.00075101666, -0.0009414684, -0.0027995398, ...   \n",
       "1      [0.0010158978, 0.00027272105, -4.0359795e-05, ...   \n",
       "2      [-0.0014983066, -0.0003533112, -0.0009976879, ...   \n",
       "3      [-0.0006638123, -0.00333717, -0.0012107268, 0....   \n",
       "4      [0.0003437847, 0.00028534047, -0.00011031516, ...   \n",
       "...                                                  ...   \n",
       "12495  [-0.001035525, 9.592343e-05, 0.0005946201, 0.0...   \n",
       "12496  [9.004027e-06, -0.00142885, 0.00081230234, -0....   \n",
       "12497  [-0.0007360069, -0.0012951059, -0.00074847107,...   \n",
       "12498  [-0.001712067, -0.0018830914, -0.0011562603, 0...   \n",
       "12499  [-0.0028270306, -0.0009945175, -0.0019709868, ...   \n",
       "\n",
       "                                                sif_wv_b  \\\n",
       "0      [-0.00057386793, -0.0015908489, -0.0026177235,...   \n",
       "1      [0.0007458804, 0.00023201294, -0.00016308203, ...   \n",
       "2      [-0.0011599427, -7.539801e-05, -0.0016003116, ...   \n",
       "3      [-0.00037023518, -0.0029268796, -0.00092696166...   \n",
       "4      [-0.00035536848, 0.0021531135, 0.0006148936, 0...   \n",
       "...                                                  ...   \n",
       "12495  [-0.00091870874, -5.8695674e-05, 0.0021628449,...   \n",
       "12496  [0.00040889718, -0.0011433773, 0.0020952262, -...   \n",
       "12497  [-0.00067913055, -0.0012926087, -0.0006993339,...   \n",
       "12498  [-0.0012003537, -0.0030034818, -0.0010703667, ...   \n",
       "12499  [-0.0040764976, -0.0011325236, -0.0029783533, ...   \n",
       "\n",
       "                                                max_wv_a  \\\n",
       "0      [-0.0077954414, 0.9239516, 0.44552782, 0.33902...   \n",
       "1      [-0.08105441, 0.9698236, 0.83412725, 1.0974805...   \n",
       "2      [-0.024600787, 0.9698236, 0.83412725, 1.097480...   \n",
       "3      [-0.08105441, 0.9698236, 0.83412725, 1.0974805...   \n",
       "4      [0.053401962, 1.2476405, 0.66729486, 0.4752677...   \n",
       "...                                                  ...   \n",
       "12495  [-0.11761739, 1.2476405, 0.7010263, 0.47526774...   \n",
       "12496  [-0.0062479675, 0.96860826, 0.28056538, 0.1461...   \n",
       "12497  [-0.016677497, 0.9698236, 0.83412725, 1.097480...   \n",
       "12498  [-0.01671742, 0.9001708, 0.6354656, 0.58082336...   \n",
       "12499  [-0.025301421, 0.9698236, 0.83412725, 1.097480...   \n",
       "\n",
       "                                                max_wv_b  \\\n",
       "0      [-0.17484796, 0.8865976, 0.43001956, 0.3390268...   \n",
       "1      [-0.08105441, 0.9698236, 0.83412725, 1.0974805...   \n",
       "2      [-0.048710022, 0.9698236, 0.83412725, 1.097480...   \n",
       "3      [-0.08105441, 0.9698236, 0.83412725, 1.0974805...   \n",
       "4      [-0.014251087, 1.2476405, 0.61681724, 0.475267...   \n",
       "...                                                  ...   \n",
       "12495  [-0.11761739, 1.1874398, 0.7010263, 0.36259052...   \n",
       "12496  [-0.043239728, 0.41075003, 0.22162169, 0.14610...   \n",
       "12497  [-0.016677497, 0.9698236, 0.83412725, 1.097480...   \n",
       "12498  [-0.011306549, 0.9001708, 0.6354656, 0.5808233...   \n",
       "12499  [-0.025301421, 0.9698236, 0.83412725, 1.097480...   \n",
       "\n",
       "                                               mean_wv_a  \\\n",
       "0      [-0.3810672, 0.56649673, 0.21015589, 0.1650711...   \n",
       "1      [-0.24039476, 0.63472, 0.37772053, 0.34956783,...   \n",
       "2      [-0.5681006, 0.7457535, 0.23108628, 0.35262153...   \n",
       "3      [-0.373247, 0.5750615, 0.2937528, 0.30965212, ...   \n",
       "4      [-0.24440448, 0.44179374, 0.21480398, 0.165082...   \n",
       "...                                                  ...   \n",
       "12495  [-0.5269485, 0.77384484, 0.33383444, 0.297116,...   \n",
       "12496  [-0.20306472, 0.27727482, 0.11639612, -0.02241...   \n",
       "12497  [-0.61691874, 0.7294527, 0.13682225, 0.4141247...   \n",
       "12498  [-0.22814763, 0.3581218, 0.21679118, 0.1769760...   \n",
       "12499  [-0.42586666, 0.5373887, 0.12877971, 0.3031228...   \n",
       "\n",
       "                                               mean_wv_b  \\\n",
       "0      [-0.49064, 0.6627479, 0.20352013, 0.19425254, ...   \n",
       "1      [-0.39853027, 0.69034606, 0.25796387, 0.324222...   \n",
       "2      [-0.6828134, 0.7691415, 0.10898125, 0.21714213...   \n",
       "3      [-0.37998262, 0.62548745, 0.34081635, 0.327185...   \n",
       "4      [-0.2544273, 0.40147883, 0.15589032, 0.1642858...   \n",
       "...                                                  ...   \n",
       "12495  [-0.3117711, 0.61655945, 0.4177478, 0.20448962...   \n",
       "12496  [-0.1049123, 0.19250993, 0.10605173, 0.0654489...   \n",
       "12497  [-0.55178684, 0.7358588, 0.19610454, 0.3864874...   \n",
       "12498  [-0.22697529, 0.35625142, 0.21671212, 0.176217...   \n",
       "12499  [-0.36814597, 0.474999, 0.1229522, 0.2690135, ...   \n",
       "\n",
       "                                                idf_wv_a  \\\n",
       "0      [-1.3579595, 1.999914, 0.69180375, 0.5824405, ...   \n",
       "1      [-1.1932749, 2.4927368, 1.257377, 0.8515301, 0...   \n",
       "2      [-1.7268541, 2.2139695, 0.6960715, 0.8554218, ...   \n",
       "3      [-1.3610084, 1.8516116, 0.9059567, 0.8284683, ...   \n",
       "4      [-0.8873614, 1.5942813, 0.7564742, 0.60922605,...   \n",
       "...                                                  ...   \n",
       "12495  [-1.4481721, 2.1688704, 1.0642663, 0.7870712, ...   \n",
       "12496  [-1.0425178, 1.5224932, 0.6920391, 0.042029757...   \n",
       "12497  [-1.2316244, 1.232954, 0.115233585, 0.59324133...   \n",
       "12498  [-0.98722535, 1.5471647, 0.8786985, 0.6976306,...   \n",
       "12499  [-1.0893434, 1.19944, 0.24725376, 0.5682123, -...   \n",
       "\n",
       "                                                idf_wv_b  \n",
       "0      [-1.8708805, 2.5265408, 0.7487483, 0.7424466, ...  \n",
       "1      [-1.3220998, 2.344159, 0.95389634, 0.7640311, ...  \n",
       "2      [-1.8062147, 2.0117328, 0.30432898, 0.4488391,...  \n",
       "3      [-1.1786901, 1.7143263, 0.9125215, 0.78298885,...  \n",
       "4      [-0.9679738, 1.5586826, 0.62105405, 0.61899275...  \n",
       "...                                                  ...  \n",
       "12495  [-1.4010504, 2.6534526, 1.7952367, 0.9098671, ...  \n",
       "12496  [-0.7344848, 1.3373768, 0.73671997, 0.44910973...  \n",
       "12497  [-1.1142532, 1.1985903, 0.17986894, 0.5401553,...  \n",
       "12498  [-0.9775583, 1.5306193, 0.8782679, 0.6917552, ...  \n",
       "12499  [-1.2632762, 1.3922013, 0.30616006, 0.62948805...  \n",
       "\n",
       "[12500 rows x 13 columns]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build sentence vectors from the trained word vectors.\n",
    "# --------------------------------------- mean/max pooling --------------------------------------- #\n",
    "\n",
    "def text_to_wv(model, data, operation='max_pooling', key='wv'):\n",
    "    '''\n",
    "    Pool per-word vectors into sentence vectors for both queries.\n",
    "\n",
    "    model: trained gensim Word2Vec model.\n",
    "    data: DataFrame with token lists in 'words_a' / 'words_b'; results are\n",
    "          written back to the new columns key + '_a' and key + '_b'.\n",
    "    operation: 'max_pooling' (element-wise max) or 'mean_pooling' (average).\n",
    "    '''\n",
    "    full_wv_a = []\n",
    "    full_wv_b = []\n",
    "    # iterate over `data` (the argument), not the global train_df, so the\n",
    "    # function works on any DataFrame passed in\n",
    "    for idx, row in tqdm(data.iterrows(), total=len(data)):\n",
    "        wv_a = [model.wv[i] for i in row['words_a']]\n",
    "        if operation == 'max_pooling':\n",
    "            full_wv_a.append(np.amax(wv_a, axis=0))\n",
    "        elif operation == 'mean_pooling':\n",
    "            full_wv_a.append(np.mean(wv_a, axis=0))\n",
    "\n",
    "        wv_b = [model.wv[i] for i in row['words_b']]\n",
    "        if operation == 'max_pooling':\n",
    "            full_wv_b.append(np.amax(wv_b, axis=0))\n",
    "        elif operation == 'mean_pooling':\n",
    "            full_wv_b.append(np.mean(wv_b, axis=0))\n",
    "    data[key + '_a'] = full_wv_a\n",
    "    data[key + '_b'] = full_wv_b\n",
    "# --------------------------------------- mean/max pooling --------------------------------------- #\n",
    "\n",
    "\n",
    "# --------------------------------------- idf pooling --------------------------------------- #\n",
    "def idf_to_wv(model, data, idf):\n",
    "    '''IDF-weighted mean sentence vectors, written to idf_wv_a / idf_wv_b.'''\n",
    "    full_wv_a = []\n",
    "    full_wv_b = []\n",
    "    for idx, row in tqdm(data.iterrows(), total=len(data)):\n",
    "        wv_a = [model.wv[i] * idf[i] for i in row['words_a']]\n",
    "        full_wv_a.append(np.mean(wv_a, axis=0))\n",
    "\n",
    "        wv_b = [model.wv[i] * idf[i] for i in row['words_b']]\n",
    "        full_wv_b.append(np.mean(wv_b, axis=0))\n",
    "    data['idf_wv_a'] = full_wv_a\n",
    "    data['idf_wv_b'] = full_wv_b\n",
    "# --------------------------------------- idf pooling --------------------------------------- #\n",
    "\n",
    "\n",
    "# --------------------------------------- sif pooling --------------------------------------- #\n",
    "def compute_pc(X, npc):\n",
    "    '''First npc principal components of X (rows are sentence vectors).'''\n",
    "    svd = TruncatedSVD(n_components=npc, n_iter=5, random_state=0)\n",
    "    svd.fit(X)\n",
    "    return svd.components_\n",
    "\n",
    "\n",
    "def remove_pc(X, npc=1):\n",
    "    '''Subtract the projection of X onto its first npc principal components.'''\n",
    "    pc = compute_pc(X, npc)\n",
    "    if npc == 1:\n",
    "        XX = X - X.dot(pc.transpose()) * pc\n",
    "    else:\n",
    "        XX = X - X.dot(pc.transpose()).dot(pc)\n",
    "    return XX\n",
    "\n",
    "\n",
    "def sif_weight(count, a=3e-5):\n",
    "    '''SIF word weights a / (a + p(w)) from raw counts (Arora et al., 2017).'''\n",
    "    # total token occurrences in the corpus\n",
    "    word_num = sum(dict(count).values())\n",
    "    # re-weight every word by its corpus probability\n",
    "    sif = {}\n",
    "    for k, v in dict(count).items():\n",
    "        sif[k] = a / (a + v / word_num)\n",
    "    return sif\n",
    "\n",
    "\n",
    "def sif_to_wv(model, data, sif):\n",
    "    '''SIF-weighted sentence vectors with the first principal component removed.'''\n",
    "    full_wv_a = []\n",
    "    full_wv_b = []\n",
    "    for idx, row in tqdm(data.iterrows(), total=len(data)):\n",
    "        wv_a = [model.wv[i] * sif[i] for i in row['words_a']]\n",
    "        full_wv_a.append(np.mean(wv_a, axis=0))\n",
    "\n",
    "        wv_b = [model.wv[i] * sif[i] for i in row['words_b']]\n",
    "        full_wv_b.append(np.mean(wv_b, axis=0))\n",
    "    # remove the common (first) principal component from both sides\n",
    "    full_wv_a = remove_pc(np.array(full_wv_a))\n",
    "    full_wv_b = remove_pc(np.array(full_wv_b))\n",
    "\n",
    "    data['sif_wv_a'] = list(full_wv_a)\n",
    "    data['sif_wv_b'] = list(full_wv_b)\n",
    "# --------------------------------------- sif pooling --------------------------------------- #\n",
    "\n",
    "\n",
    "# max-pooled sentence vectors\n",
    "text_to_wv(wv_model, train_df, 'max_pooling', 'max_wv')\n",
    "# mean-pooled sentence vectors\n",
    "text_to_wv(wv_model, train_df, 'mean_pooling', 'mean_wv')\n",
    "# idf-weighted mean sentence vectors\n",
    "idf_to_wv(wv_model, train_df, idf)\n",
    "# sif weights, then sif sentence vectors\n",
    "sif = sif_weight(count)\n",
    "sif_to_wv(wv_model, train_df, sif)\n",
    "train_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "78902c44",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12500it [00:04, 3014.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean_sim\n",
      "不相似的距离 0.9830239296150207\n",
      "相似的距离 0.9887247297382354\n",
      "max_sim\n",
      "不相似的距离 0.9722367079353332\n",
      "相似的距离 0.9804226670265198\n",
      "idf_sim\n",
      "不相似的距离 0.9855579686450958\n",
      "相似的距离 0.9928204106330871\n",
      "sif_sim\n",
      "不相似的距离 0.8141846484244988\n",
      "相似的距离 0.944569531496428\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# For each pooling strategy, compute the cosine similarity of every\n",
    "# sentence pair, store it as a new column, and compare the average\n",
    "# similarity of unmatched (label 0) vs. matched (label 1) pairs.\n",
    "pooling_names = ['mean', 'max', 'idf', 'sif']\n",
    "sim_lists = {name: [] for name in pooling_names}\n",
    "for idx, row in tqdm(train_df.iterrows()):\n",
    "    for name in pooling_names:\n",
    "        similarity = 1 - spatial.distance.cosine(row[name + '_wv_a'], row[name + '_wv_b'])\n",
    "        sim_lists[name].append(similarity)\n",
    "\n",
    "for name in pooling_names:\n",
    "    train_df[name + '_sim'] = sim_lists[name]\n",
    "\n",
    "for name in pooling_names:\n",
    "    sim_col = name + '_sim'\n",
    "    print(sim_col)\n",
    "    print('不相似的距离', np.mean(train_df[train_df['label'] == 0][sim_col].tolist()))\n",
    "    print('相似的距离', np.mean(train_df[train_df['label'] == 1][sim_col].tolist()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7272e861-2450-4178-9ee5-03bd1564f675",
   "metadata": {},
   "source": [
    "## 任务5：文本匹配模型（LSTM孪生网络）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9225ae59-22f8-4e46-b966-258e7f3af2a6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>query1</th>\n",
       "      <th>query2</th>\n",
       "      <th>label</th>\n",
       "      <th>words_a</th>\n",
       "      <th>words_b</th>\n",
       "      <th>word2num_a</th>\n",
       "      <th>word2num_b</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>喜欢打篮球的男生喜欢什么样的女生</td>\n",
       "      <td>爱打篮球的男生喜欢什么样的女生</td>\n",
       "      <td>1</td>\n",
       "      <td>[喜欢, 打篮球, 篮球, 的, 男生, 喜欢, 什么, 什么样, 的, 女生]</td>\n",
       "      <td>[爱, 打篮球, 篮球, 的, 男生, 喜欢, 什么, 什么样, 的, 女生]</td>\n",
       "      <td>[0, 1, 2, 3, 4, 0, 5, 6, 3, 7]</td>\n",
       "      <td>[8, 1, 2, 3, 4, 0, 5, 6, 3, 7]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>我手机丢了，我想换个手机</td>\n",
       "      <td>我想买个新手机，求推荐</td>\n",
       "      <td>1</td>\n",
       "      <td>[我, 手机, 丢, 了, ，, 我, 想, 换, 个, 手机]</td>\n",
       "      <td>[我, 想买, 个, 新手, 新手机, 手机, ，, 求, 推荐]</td>\n",
       "      <td>[9, 10, 11, 12, 13, 9, 14, 15, 16, 10]</td>\n",
       "      <td>[9, 17, 16, 18, 19, 10, 13, 20, 21]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>大家觉得她好看吗</td>\n",
       "      <td>大家觉得跑男好看吗？</td>\n",
       "      <td>0</td>\n",
       "      <td>[大家, 觉得, 她, 好看, 吗]</td>\n",
       "      <td>[大家, 觉得, 跑, 男, 好看, 吗, ？]</td>\n",
       "      <td>[22, 23, 24, 25, 26]</td>\n",
       "      <td>[22, 23, 27, 28, 25, 26, 29]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>求秋色之空漫画全集</td>\n",
       "      <td>求秋色之空全集漫画</td>\n",
       "      <td>1</td>\n",
       "      <td>[求, 秋色, 之, 空, 漫画, 全集]</td>\n",
       "      <td>[求, 秋色, 之, 空, 全集, 漫画]</td>\n",
       "      <td>[20, 30, 31, 32, 33, 34]</td>\n",
       "      <td>[20, 30, 31, 32, 34, 33]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>晚上睡觉带着耳机听音乐有什么害处吗？</td>\n",
       "      <td>孕妇可以戴耳机听音乐吗?</td>\n",
       "      <td>0</td>\n",
       "      <td>[晚上, 睡觉, 带, 着, 耳机, 听音, 音乐, 有, 什么, 害处, 吗, ？]</td>\n",
       "      <td>[孕妇, 可以, 戴, 耳机, 听音, 音乐, 吗, ?]</td>\n",
       "      <td>[35, 36, 37, 38, 39, 40, 41, 42, 5, 43, 26, 29]</td>\n",
       "      <td>[44, 45, 46, 39, 40, 41, 26, 47]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               query1           query2  label  \\\n",
       "0    喜欢打篮球的男生喜欢什么样的女生  爱打篮球的男生喜欢什么样的女生      1   \n",
       "1        我手机丢了，我想换个手机      我想买个新手机，求推荐      1   \n",
       "2            大家觉得她好看吗       大家觉得跑男好看吗？      0   \n",
       "3           求秋色之空漫画全集        求秋色之空全集漫画      1   \n",
       "4  晚上睡觉带着耳机听音乐有什么害处吗？     孕妇可以戴耳机听音乐吗?      0   \n",
       "\n",
       "                                       words_a  \\\n",
       "0     [喜欢, 打篮球, 篮球, 的, 男生, 喜欢, 什么, 什么样, 的, 女生]   \n",
       "1             [我, 手机, 丢, 了, ，, 我, 想, 换, 个, 手机]   \n",
       "2                           [大家, 觉得, 她, 好看, 吗]   \n",
       "3                        [求, 秋色, 之, 空, 漫画, 全集]   \n",
       "4  [晚上, 睡觉, 带, 着, 耳机, 听音, 音乐, 有, 什么, 害处, 吗, ？]   \n",
       "\n",
       "                                   words_b  \\\n",
       "0  [爱, 打篮球, 篮球, 的, 男生, 喜欢, 什么, 什么样, 的, 女生]   \n",
       "1        [我, 想买, 个, 新手, 新手机, 手机, ，, 求, 推荐]   \n",
       "2                 [大家, 觉得, 跑, 男, 好看, 吗, ？]   \n",
       "3                    [求, 秋色, 之, 空, 全集, 漫画]   \n",
       "4            [孕妇, 可以, 戴, 耳机, 听音, 音乐, 吗, ?]   \n",
       "\n",
       "                                        word2num_a  \\\n",
       "0                   [0, 1, 2, 3, 4, 0, 5, 6, 3, 7]   \n",
       "1           [9, 10, 11, 12, 13, 9, 14, 15, 16, 10]   \n",
       "2                             [22, 23, 24, 25, 26]   \n",
       "3                         [20, 30, 31, 32, 33, 34]   \n",
       "4  [35, 36, 37, 38, 39, 40, 41, 42, 5, 43, 26, 29]   \n",
       "\n",
       "                            word2num_b  \n",
       "0       [8, 1, 2, 3, 4, 0, 5, 6, 3, 7]  \n",
       "1  [9, 17, 16, 18, 19, 10, 13, 20, 21]  \n",
       "2         [22, 23, 27, 28, 25, 26, 29]  \n",
       "3             [20, 30, 31, 32, 34, 33]  \n",
       "4     [44, 45, 46, 39, 40, 41, 26, 47]  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def build_vocab(data_df_list):\n",
    "    '''\n",
    "        构建词典映射\n",
    "        data_df_list: 读取数据df的列表\n",
    "    '''\n",
    "    word_index = 0\n",
    "    vocab_dict = {}\n",
    "    for data_df in data_df_list:\n",
    "        for idx, row in data_df.iterrows():\n",
    "            words_a = row['words_a']\n",
    "            words_b = row['words_b']\n",
    "            for word in words_a:\n",
    "                if word not in vocab_dict.keys():\n",
    "                    vocab_dict[word] = word_index\n",
    "                    word_index += 1\n",
    "            for word in words_b:\n",
    "                if word not in vocab_dict.keys():\n",
    "                    vocab_dict[word] = word_index\n",
    "                    word_index += 1\n",
    "    return vocab_dict\n",
    "\n",
    "\n",
    "def word2num(content, vocab):\n",
    "    '''\n",
    "        将中文词转换为词典数字\n",
    "        content: 分词后的文本句子\n",
    "        vocab：词典\n",
    "    '''\n",
    "    result = []\n",
    "    for word in content:\n",
    "        result.append(vocab[word])\n",
    "    return result\n",
    "\n",
    "# Load the training pairs\n",
    "data = pd.read_csv('data/train.csv')\n",
    "# Tokenize both sentences (jieba full mode: emits all overlapping segments)\n",
    "data['words_a'] = data['query1'].apply(lambda x: jieba.lcut(x, cut_all=True))\n",
    "data['words_b'] = data['query2'].apply(lambda x: jieba.lcut(x, cut_all=True))\n",
    "# Build the word -> id vocabulary over the whole dataset\n",
    "vocab = build_vocab([data])\n",
    "\n",
    "# Map each token list to its id list\n",
    "data['word2num_a'] = data['words_a'].apply(lambda x: word2num(x, vocab))\n",
    "data['word2num_b'] = data['words_b'].apply(lambda x: word2num(x, vocab))\n",
    "    \n",
    "data[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "de9a187d-3e68-402a-9229-c6982d97da76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset for the siamese LSTM: pads/truncates both id sequences of a\n",
    "# pair to a fixed length and returns them with the pair label.\n",
    "class SimDataset(Dataset):\n",
    "    def __init__(self, df, max_seq_size):\n",
    "        '''\n",
    "            df: DataFrame with `word2num_a`/`word2num_b` (lists of int ids)\n",
    "                and `label` (0/1) columns\n",
    "            max_seq_size: fixed length each sequence is padded/cut to\n",
    "        '''\n",
    "        super(SimDataset, self).__init__()\n",
    "        self.text_a = df['word2num_a']\n",
    "        self.text_b = df['word2num_b']\n",
    "        self.label = df['label']\n",
    "        self.len = len(df)\n",
    "        self.max_seq_size = max_seq_size\n",
    "        \n",
    "    def check_max_seq_size(self, vector):\n",
    "        '''Truncate to max_seq_size, or right-pad with zeros (id 0).'''\n",
    "        if len(vector) >= self.max_seq_size:\n",
    "            return vector[:self.max_seq_size]\n",
    "        else:\n",
    "            new_vector = np.zeros((self.max_seq_size))\n",
    "            new_vector[:len(vector)] = vector\n",
    "            return new_vector \n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        vector_a = np.array(self.text_a.iloc[idx], dtype='int32')\n",
    "        vector_b = np.array(self.text_b.iloc[idx], dtype='int32')\n",
    "        \n",
    "        vector_a = self.check_max_seq_size(vector_a)\n",
    "        vector_b = self.check_max_seq_size(vector_b)\n",
    "        \n",
    "        label = np.array(self.label.iloc[idx]).astype(\"int64\")\n",
    "\n",
    "        # BUG FIX: the 'vector_b' entry previously returned vector_a, so the\n",
    "        # siamese network always compared each sentence with itself.\n",
    "        return {'vector_a': paddle.to_tensor(vector_a, dtype='int32'),\n",
    "                'vector_b': paddle.to_tensor(vector_b, dtype='int32'),\n",
    "                'label': paddle.to_tensor(label, dtype='int32')}\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "47af84fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model definition\n",
    "class LinModel(nn.Layer):\n",
    "    '''4-layer MLP classification head producing class logits.'''\n",
    "\n",
    "    def __init__(self, in_features, out_features):\n",
    "        super(LinModel, self).__init__()\n",
    "\n",
    "        self.fc_1 = nn.Sequential(\n",
    "            nn.Linear(in_features, 256),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.02)\n",
    "        )\n",
    "        self.fc_2 = nn.Sequential(\n",
    "            nn.Linear(256, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.02)\n",
    "        )\n",
    "        self.fc_3 = nn.Sequential(\n",
    "            nn.Linear(32, 4),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.02)\n",
    "        )\n",
    "        self.fc_4 = nn.Sequential(\n",
    "            nn.Linear(4, out_features),\n",
    "        )\n",
    "\n",
    "    def forward(self, X):\n",
    "        X = self.fc_1(X)\n",
    "        X = self.fc_2(X)\n",
    "        X = self.fc_3(X)\n",
    "        # BUG FIX: return raw logits instead of softmax probabilities.\n",
    "        # paddle's CrossEntropyLoss applies softmax internally\n",
    "        # (use_softmax=True by default), so the previous extra Softmax here\n",
    "        # softmax-ed the scores twice and flattened the gradients.\n",
    "        # paddle.argmax on the output is unaffected (softmax is monotonic).\n",
    "        return self.fc_4(X)\n",
    "\n",
    "\n",
    "\n",
    "class Siamese(nn.Layer):\n",
    "    '''Siamese bi-LSTM: encodes two id sequences with shared weights,\n",
    "    crosses the encodings three ways and classifies match / no match.'''\n",
    "    def __init__(self, config):\n",
    "        super(Siamese, self).__init__()\n",
    "        hidden_size = 10\n",
    "        # Shared embedding table\n",
    "        self.embedding = nn.Embedding(config['vocab_size'], config['embed_size'])\n",
    "        # Shared two-layer bidirectional LSTM\n",
    "        self.rnn = nn.LSTM(input_size=config['embed_size'], hidden_size=hidden_size, num_layers=2, direction='bidirectional')\n",
    "        \n",
    "        # Head input size: max_seq_len steps, each 2*hidden_size wide per\n",
    "        # encoder output, combined 4 widths in total (concat of a and b = 2,\n",
    "        # product = 1, difference = 1). Previously hard-coded to\n",
    "        # 1200 (= 15 * 10 * 2 * 4); now derived from config.\n",
    "        self.lin_model = LinModel(config['max_seq_len'] * hidden_size * 2 * 4, 2)\n",
    " \n",
    "    def forward(self, words_a, words_b):\n",
    "        # Encode sentence a\n",
    "        x_a = self.embedding(words_a)\n",
    "        # LSTM is fed time-major here; transpose in and back out\n",
    "        x_a = paddle.transpose(x_a, perm=[1, 0, 2])\n",
    "        x_a, _ = self.rnn(x_a)\n",
    "        x_a = paddle.transpose(x_a, perm=[1, 0, 2])\n",
    "\n",
    "        # Encode sentence b with the same weights\n",
    "        x_b = self.embedding(words_b)\n",
    "        x_b = paddle.transpose(x_b, perm=[1, 0, 2])\n",
    "        x_b, _ = self.rnn(x_b)\n",
    "        x_b = paddle.transpose(x_b, perm=[1, 0, 2])\n",
    "        \n",
    "        # Three ways of crossing the two encodings:\n",
    "        # 1) concatenation\n",
    "        X_1 = paddle.concat([x_a, x_b], 2) \n",
    "        # 2) element-wise product\n",
    "        X_2 = x_a.multiply(x_b)\n",
    "        # 3) element-wise difference\n",
    "        X_3 = x_a.subtract(x_b)\n",
    "\n",
    "        # Concatenate the three views and flatten per sample\n",
    "        X = paddle.concat([X_1, X_2, X_3], 2)\n",
    "        X = paddle.reshape(X, shape=[X.shape[0], -1])\n",
    "        \n",
    "        # Classify (returns logits)\n",
    "        output = self.lin_model(X)\n",
    "        \n",
    "        return output\n",
    "        \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0645db98",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, iter: 0, loss is: 0.692991316318512, acc: 0.546875\n",
      "epoch: 0, iter: 150, loss is: 0.6013166308403015, acc: 0.69140625\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[1;32mIn [9]\u001b[0m, in \u001b[0;36m<cell line: 27>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     27\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(config[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mepochs\u001b[39m\u001b[38;5;124m'\u001b[39m]):\n\u001b[0;32m     28\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m iter_id, mini_batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(train_dataloader):\n\u001b[0;32m     29\u001b[0m         \u001b[38;5;66;03m# 计算\u001b[39;00m\n\u001b[1;32m---> 30\u001b[0m         output \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmini_batch\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mvector_a\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmini_batch\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mvector_b\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     31\u001b[0m         loss \u001b[38;5;241m=\u001b[39m loss_fn(output, mini_batch[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlabel\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[0;32m     32\u001b[0m         \u001b[38;5;66;03m# 获取推理结果\u001b[39;00m\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\paddle\\fluid\\dygraph\\layers.py:948\u001b[0m, in \u001b[0;36mLayer.__call__\u001b[1;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[0;32m    945\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m in_declarative_mode()) \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks) \\\n\u001b[0;32m    946\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_post_hooks) \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_built) \u001b[38;5;129;01mand\u001b[39;00m in_dygraph_mode() \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m in_profiler_mode()):\n\u001b[0;32m    947\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_build_once(\u001b[38;5;241m*\u001b[39minputs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m--> 948\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(\u001b[38;5;241m*\u001b[39minputs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m    949\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m    950\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dygraph_call_func(\u001b[38;5;241m*\u001b[39minputs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "Input \u001b[1;32mIn [6]\u001b[0m, in \u001b[0;36mSiamese.forward\u001b[1;34m(self, words_a, words_b)\u001b[0m\n\u001b[0;32m     58\u001b[0m x_b \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membedding(words_b)\n\u001b[0;32m     59\u001b[0m x_b \u001b[38;5;241m=\u001b[39m paddle\u001b[38;5;241m.\u001b[39mtranspose(x_b, perm\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m])\n\u001b[1;32m---> 60\u001b[0m x_b, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrnn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx_b\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     61\u001b[0m x_b \u001b[38;5;241m=\u001b[39m paddle\u001b[38;5;241m.\u001b[39mtranspose(x_b, perm\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m])\n\u001b[0;32m     63\u001b[0m \u001b[38;5;124;03m'''\u001b[39;00m\n\u001b[0;32m     64\u001b[0m \u001b[38;5;124;03m    三种编码的交叉方式\u001b[39;00m\n\u001b[0;32m     65\u001b[0m \u001b[38;5;124;03m'''\u001b[39;00m\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\paddle\\fluid\\dygraph\\layers.py:948\u001b[0m, in \u001b[0;36mLayer.__call__\u001b[1;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[0;32m    945\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m in_declarative_mode()) \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks) \\\n\u001b[0;32m    946\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_post_hooks) \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_built) \u001b[38;5;129;01mand\u001b[39;00m in_dygraph_mode() \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m in_profiler_mode()):\n\u001b[0;32m    947\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_build_once(\u001b[38;5;241m*\u001b[39minputs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m--> 948\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(\u001b[38;5;241m*\u001b[39minputs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m    949\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m    950\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dygraph_call_func(\u001b[38;5;241m*\u001b[39minputs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\paddle\\nn\\layer\\rnn.py:1082\u001b[0m, in \u001b[0;36mRNNBase.forward\u001b[1;34m(self, inputs, initial_states, sequence_length)\u001b[0m\n\u001b[0;32m   1076\u001b[0m     initial_states \u001b[38;5;241m=\u001b[39m [initial_states] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[0;32m   1077\u001b[0m         initial_states, paddle\u001b[38;5;241m.\u001b[39mstatic\u001b[38;5;241m.\u001b[39mVariable) \u001b[38;5;28;01melse\u001b[39;00m initial_states\n\u001b[0;32m   1079\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcould_use_cudnn \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m paddle\u001b[38;5;241m.\u001b[39mdevice\u001b[38;5;241m.\u001b[39mis_compiled_with_rocm()\n\u001b[0;32m   1080\u001b[0m                              \u001b[38;5;129;01mor\u001b[39;00m sequence_length \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m   1081\u001b[0m     \u001b[38;5;66;03m# Add CPU kernel and dispatch in backend later\u001b[39;00m\n\u001b[1;32m-> 1082\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cudnn_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minitial_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msequence_length\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1084\u001b[0m states \u001b[38;5;241m=\u001b[39m split_states(initial_states, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_directions \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m,\n\u001b[0;32m   1085\u001b[0m                       \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate_components)\n\u001b[0;32m   1086\u001b[0m final_states \u001b[38;5;241m=\u001b[39m []\n",
      "File \u001b[1;32mE:\\env\\Anaconda\\lib\\site-packages\\paddle\\nn\\layer\\rnn.py:1016\u001b[0m, in \u001b[0;36mRNNBase._cudnn_impl\u001b[1;34m(self, inputs, initial_states, sequence_length)\u001b[0m\n\u001b[0;32m   1013\u001b[0m     inputs \u001b[38;5;241m=\u001b[39m paddle\u001b[38;5;241m.\u001b[39mtensor\u001b[38;5;241m.\u001b[39mtranspose(inputs, [\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m])\n\u001b[0;32m   1015\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m in_dynamic_mode():\n\u001b[1;32m-> 1016\u001b[0m     _, _, out, state \u001b[38;5;241m=\u001b[39m \u001b[43m_legacy_C_ops\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrnn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1017\u001b[0m \u001b[43m        \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minitial_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_all_weights\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msequence_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1018\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dropout_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstate_components\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mdropout_prob\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1019\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mis_bidirec\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_directions\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1020\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43minput_size\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mhidden_size\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhidden_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1021\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mnum_layers\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_layers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmode\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mis_test\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1022\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1023\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m   1024\u001b[0m     out \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_helper\u001b[38;5;241m.\u001b[39mcreate_variable_for_type_inference(inputs\u001b[38;5;241m.\u001b[39mdtype)\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "paddle.device.set_device('gpu')\n",
    "config = {\n",
    "    'batch_size': 256,\n",
    "    'vocab_size': len(vocab),\n",
    "    'max_seq_len': 15,\n",
    "    'embed_size': 100,\n",
    "    'lr': 1e-3,\n",
    "    'epochs': 1\n",
    "}\n",
    "# Build the dataset / dataloader\n",
    "train_dataset = SimDataset(data, config['max_seq_len'])\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, drop_last=True)\n",
    "\n",
    "# Build the model\n",
    "model = Siamese(config)\n",
    "\n",
    "\n",
    "# Loss function\n",
    "loss_fn = nn.loss.CrossEntropyLoss()\n",
    "# Optimizer\n",
    "# NOTE(review): beta1/beta2 are defined but never passed to AdamW below —\n",
    "# apparently dead code; confirm and remove or wire them in.\n",
    "beta1 = paddle.to_tensor([0.9], dtype=\"float32\")\n",
    "beta2 = paddle.to_tensor([0.99], dtype=\"float32\")\n",
    "\n",
    "opt = paddle.optimizer.AdamW(learning_rate=config['lr'],\n",
    "        parameters=model.parameters(),\n",
    "        weight_decay=0.01)\n",
    "for epoch in range(config['epochs']):\n",
    "    for iter_id, mini_batch in enumerate(train_dataloader):\n",
    "        # Forward pass and loss\n",
    "        output = model(mini_batch['vector_a'], mini_batch['vector_b'])\n",
    "        loss = loss_fn(output, mini_batch['label'])\n",
    "        # Predicted class per sample (argmax over class dimension)\n",
    "        y_pred = paddle.argmax(output, axis=1)\n",
    "        y_pred = y_pred.numpy()\n",
    "        y_true = mini_batch['label'].numpy()\n",
    "\n",
    "        # Periodically report loss and batch accuracy\n",
    "        if iter_id % 150 == 0:\n",
    "            print('epoch: {}, iter: {}, loss is: {}, acc: {}'.format(epoch, iter_id, loss.numpy()[0], np.sum(y_pred == y_true)/len(y_true)))\n",
    "        # Backprop, parameter update, gradient reset (order matters)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.clear_grad()\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54cd88c8-3840-4568-996a-0f12c3853652",
   "metadata": {},
   "source": [
    "## 任务6：文本匹配模型（BERT模型）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ef26494d-c548-4365-a2d5-cbe978a935fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset over pre-tokenized examples; each item is a\n",
    "# (input_ids, token_type_ids, attention_mask, label) tuple.\n",
    "class SimDataset(Dataset):\n",
    "    def __init__(self, data_dict):\n",
    "        '''data_dict: dict of four parallel lists keyed by\n",
    "        input_ids / token_type_ids / attention_mask / labels.'''\n",
    "        super(SimDataset, self).__init__()\n",
    "        self.input_ids = data_dict['input_ids']\n",
    "        self.token_type_ids = data_dict['token_type_ids']\n",
    "        self.attention_mask = data_dict['attention_mask']\n",
    "        self.labels = data_dict['labels']\n",
    "        self.len = len(self.input_ids)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # Return the fields of the index-th example directly as a tuple.\n",
    "        return (self.input_ids[index],\n",
    "                self.token_type_ids[index],\n",
    "                self.attention_mask[index],\n",
    "                self.labels[index])\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.len\n",
    "\n",
    "# Batch collator: pads/truncates a list of examples to a common length.\n",
    "class Collator:\n",
    "\n",
    "    def __init__(self, tokenizer, max_seq_len):\n",
    "        # tokenizer: provides sep_token_id used when truncating\n",
    "        # max_seq_len: hard upper bound on the padded batch length\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_seq_len = max_seq_len\n",
    "\n",
    "    def pad(self, input_ids_list, token_type_ids_list, attention_mask_list, labels_list, max_seq_len):\n",
    "        '''Pad (with zeros) or truncate every example to max_seq_len and\n",
    "        return batched int64 tensors: ids, segment ids, masks, labels.'''\n",
    "        # Pre-allocate zero tensors of the padded batch shape\n",
    "        input_ids = paddle.zeros((len(input_ids_list), max_seq_len), dtype='int64')\n",
    "        token_type_ids = paddle.zeros_like(input_ids)\n",
    "        attention_mask = paddle.zeros_like(input_ids)\n",
    "        # Fill in each example\n",
    "        for i in range(len(input_ids_list)):\n",
    "            seq_len = len(input_ids_list[i])\n",
    "\n",
    "            if seq_len < max_seq_len:  # shorter than max: left-aligned copy, rest stays zero-padded\n",
    "                input_ids[i, :seq_len] = paddle.to_tensor(input_ids_list[i], dtype='int64')\n",
    "                token_type_ids[i, :seq_len] = paddle.to_tensor(token_type_ids_list[i], dtype='int64')\n",
    "                attention_mask[i, :seq_len] = paddle.to_tensor(attention_mask_list[i], dtype='int64')\n",
    "            else:  # equal or longer: truncate\n",
    "                # Keep the tokenizer's separator token as the last position\n",
    "                input_ids[i] = paddle.to_tensor(\n",
    "                    input_ids_list[i][:max_seq_len - 1] + [self.tokenizer.sep_token_id], dtype='int64')\n",
    "                token_type_ids[i] = paddle.to_tensor(\n",
    "                    token_type_ids_list[i][:max_seq_len], dtype='int64')\n",
    "                attention_mask[i] = paddle.to_tensor(\n",
    "                    attention_mask_list[i][:max_seq_len], dtype='int64')\n",
    "        # Labels as a (batch, 1) int64 tensor\n",
    "        labels = paddle.to_tensor([[label]for label in labels_list], dtype='int64')\n",
    "\n",
    "        return input_ids, token_type_ids, attention_mask, labels\n",
    "\n",
    "    def __call__(self, examples):\n",
    "        '''examples: list of (input_ids, token_type_ids, attention_mask, label)\n",
    "        tuples as produced by SimDataset. Returns a dict of batched tensors.'''\n",
    "        # Unzip the batch into parallel lists\n",
    "        input_ids_list, token_type_ids_list, attention_mask_list, labels_list = list(zip(*examples))\n",
    "        # Pad only as far as needed: longest sequence in this batch,\n",
    "        # capped at the configured maximum\n",
    "        cur_seq_len = max([len(ids) for ids in input_ids_list])  # longest in this batch\n",
    "        max_seq_len = min(cur_seq_len, self.max_seq_len)  # effective padded length\n",
    "        # Pad/truncate the whole batch\n",
    "        input_ids, token_type_ids, attention_mask, labels = self.pad(input_ids_list, token_type_ids_list,\n",
    "                                                                     attention_mask_list, labels_list, max_seq_len)\n",
    "        # Package the batch\n",
    "        data = {\n",
    "            'input_ids': input_ids,\n",
    "            'token_type_ids': token_type_ids,\n",
    "            'attention_mask': attention_mask,\n",
    "            'labels': labels,\n",
    "        }\n",
    "        return data\n",
    "\n",
    "\n",
    "# Build the training dataloader\n",
    "def create_dataloader(config):\n",
    "    '''Read the first 2000 training pairs, tokenize each pair jointly,\n",
    "    and wrap them in a shuffled DataLoader using the Collator above.\n",
    "    config: Config instance providing tokenizer, max_seq_len, batch_size.'''\n",
    "    # Load the raw pairs (subset for speed)\n",
    "    train = pd.read_csv('data/train.csv')[:2000]\n",
    "    data_dict = {'input_ids':[],'token_type_ids':[],'attention_mask':[],'labels':[]}\n",
    "    for i, row in tqdm(train.iterrows()):\n",
    "        # NOTE(review): positional row[0..2] assumes the CSV column order is\n",
    "        # query1, query2, label — confirm against the file header.\n",
    "        seq_a = row[0]\n",
    "        seq_b = row[1]\n",
    "        label = row[2]\n",
    "        # Joint encoding of the sentence pair (ids + segment ids + mask)\n",
    "        inputs_dict = config.tokenizer.encode(seq_a, seq_b, return_special_tokens_mask=True, return_token_type_ids=True,\n",
    "                                                return_attention_mask=True)\n",
    "        data_dict['input_ids'].append(inputs_dict['input_ids'])\n",
    "        data_dict['token_type_ids'].append(inputs_dict['token_type_ids'])\n",
    "        data_dict['attention_mask'].append(inputs_dict['attention_mask'])\n",
    "        data_dict['labels'].append(label)\n",
    "    # Wrap in a Dataset\n",
    "    train_dataset = SimDataset(data_dict)\n",
    "\n",
    "    # Wrap in a DataLoader with dynamic per-batch padding\n",
    "    collate_fn = Collator(config.tokenizer, config.max_seq_len)\n",
    "    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, collate_fn=collate_fn, shuffle=True,\n",
    "                                  num_workers=0)\n",
    "    return train_dataloader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fe2ea3f7-b3d8-452e-8788-71693e251599",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2000it [00:00, 2845.96it/s]\n"
     ]
    }
   ],
   "source": [
    "class Config:\n",
    "    \"\"\"Hyper-parameters for the cross-encoder fine-tuning run.\"\"\"\n",
    "    # Data loading\n",
    "    max_seq_len = 15*2  # max tokens for the packed sentence pair (~15 per sentence)\n",
    "    # Model\n",
    "    model_path = 'ernie-3.0-base-zh'  # pretrained model name or local path\n",
    "    tokenizer = None  # tokenizer instance, filled in below\n",
    "    # Training\n",
    "    learning_rate = 3e-5\n",
    "    batch_size = 16  # batch size\n",
    "    epochs = 15  # number of training epochs\n",
    "    print_loss = 50  # log every N iterations\n",
    "    num_labels = 2  # number of classes\n",
    "\n",
    "config = Config()\n",
    "# Load the tokenizer matching the pretrained model\n",
    "config.tokenizer = AutoTokenizer.from_pretrained(config.model_path)\n",
    "# Build the training dataloader\n",
    "train_dataloader = create_dataloader(config)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d0a49650-85a7-4900-a3c3-6446a5537f0e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0, iter_id:0, loss:Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [0.73235011]), acc:0.375\n",
      "epoch:0, iter_id:50, loss:Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [0.51563501]), acc:0.7218137254901961\n",
      "epoch:0, iter_id:100, loss:Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [0.19683793]), acc:0.7790841584158416\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-5-544cc3cbde1d>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     14\u001b[0m         \u001b[0mattention_mask\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmini_batch\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'attention_mask'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     15\u001b[0m         \u001b[0mlabels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmini_batch\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'labels'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 16\u001b[1;33m         \u001b[0mlogits\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtoken_type_ids\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtoken_type_ids\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     17\u001b[0m         \u001b[1;31m# 计算损失值\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     18\u001b[0m         \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mloss_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\paddle\\fluid\\dygraph\\layers.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[0;32m    946\u001b[0m             \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_post_hooks\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_built\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0min_dygraph_mode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0min_profiler_mode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    947\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_build_once\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 948\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    949\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    950\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dygraph_call_func\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\paddlenlp\\transformers\\ernie\\modeling.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input_ids, token_type_ids, position_ids, attention_mask, inputs_embeds, labels, output_hidden_states, output_attentions, return_dict)\u001b[0m\n\u001b[0;32m    458\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    459\u001b[0m         \"\"\"\n\u001b[1;32m--> 460\u001b[1;33m         outputs = self.ernie(\n\u001b[0m\u001b[0;32m    461\u001b[0m             \u001b[0minput_ids\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    462\u001b[0m             \u001b[0mtoken_type_ids\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtoken_type_ids\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\paddle\\fluid\\dygraph\\layers.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[0;32m    946\u001b[0m             \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_post_hooks\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_built\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0min_dygraph_mode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0min_profiler_mode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    947\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_build_once\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 948\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    949\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    950\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dygraph_call_func\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\paddlenlp\\transformers\\ernie\\modeling.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input_ids, token_type_ids, position_ids, attention_mask, task_type_ids, past_key_values, inputs_embeds, use_cache, output_hidden_states, output_attentions, return_dict)\u001b[0m\n\u001b[0;32m    353\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    354\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_use_cache\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0muse_cache\u001b[0m  \u001b[1;31m# To be consistent with HF\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 355\u001b[1;33m         encoder_outputs = self.encoder(\n\u001b[0m\u001b[0;32m    356\u001b[0m             \u001b[0membedding_output\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    357\u001b[0m             \u001b[0msrc_mask\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\paddle\\fluid\\dygraph\\layers.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[0;32m    946\u001b[0m             \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_post_hooks\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_built\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0min_dygraph_mode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0min_profiler_mode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    947\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_build_once\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 948\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    949\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    950\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dygraph_call_func\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\paddlenlp\\transformers\\model_outputs.py\u001b[0m in \u001b[0;36m_transformer_encoder_fwd\u001b[1;34m(self, src, src_mask, cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[0;32m    293\u001b[0m             )\n\u001b[0;32m    294\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 295\u001b[1;33m             layer_outputs = mod(\n\u001b[0m\u001b[0;32m    296\u001b[0m                 \u001b[0moutput\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    297\u001b[0m                 \u001b[0msrc_mask\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msrc_mask\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\paddle\\fluid\\dygraph\\layers.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[0;32m    946\u001b[0m             \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_post_hooks\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_built\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0min_dygraph_mode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0min_profiler_mode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    947\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_build_once\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 948\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    949\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    950\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dygraph_call_func\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\paddlenlp\\transformers\\model_outputs.py\u001b[0m in \u001b[0;36m_transformer_encoder_layer_fwd\u001b[1;34m(self, src, src_mask, cache, output_attentions)\u001b[0m\n\u001b[0;32m     96\u001b[0m         \u001b[0msrc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnorm2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     97\u001b[0m     \u001b[0msrc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mactivation\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 98\u001b[1;33m     \u001b[0msrc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mresidual\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdropout2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     99\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnormalize_before\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    100\u001b[0m         \u001b[0msrc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnorm2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\paddle\\fluid\\dygraph\\layers.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[0;32m    946\u001b[0m             \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_post_hooks\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_built\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0min_dygraph_mode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;32mnot\u001b[0m \u001b[0min_profiler_mode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    947\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_build_once\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 948\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    949\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    950\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dygraph_call_func\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\paddle\\nn\\layer\\common.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    763\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    764\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 765\u001b[1;33m         out = F.dropout(\n\u001b[0m\u001b[0;32m    766\u001b[0m             \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    767\u001b[0m             \u001b[0mp\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mp\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\paddle\\nn\\functional\\common.py\u001b[0m in \u001b[0;36mdropout\u001b[1;34m(x, p, axis, training, mode, name)\u001b[0m\n\u001b[0;32m   1132\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1133\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0min_dygraph_mode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1134\u001b[1;33m                 out, mask = _C_ops.dropout(\n\u001b[0m\u001b[0;32m   1135\u001b[0m                     \u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1136\u001b[0m                     \u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Build the sequence-classification model on top of the pretrained encoder\n",
    "model = AutoModelForSequenceClassification.from_pretrained(config.model_path,num_classes=2)\n",
    "# Optimizer\n",
    "opt = paddle.optimizer.AdamW(learning_rate=config.learning_rate, parameters=model.parameters())\n",
    "# Loss function and accuracy metric\n",
    "loss_fn = nn.loss.CrossEntropyLoss()\n",
    "metric = paddle.metric.Accuracy()\n",
    "# Training loop\n",
    "for epoch in range(config.epochs):\n",
    "    model.train()\n",
    "    # Reset so the reported accuracy covers this epoch only; without this\n",
    "    # the metric keeps accumulating across all epochs\n",
    "    metric.reset()\n",
    "    for iter_id, mini_batch in enumerate(train_dataloader):\n",
    "        input_ids = mini_batch['input_ids']\n",
    "        token_type_ids = mini_batch['token_type_ids']\n",
    "        # NOTE(review): attention_mask is collated but never passed to the model;\n",
    "        # presumably the encoder derives the mask from pad tokens — confirm intended\n",
    "        attention_mask = mini_batch['attention_mask']\n",
    "        labels = mini_batch['labels']\n",
    "        logits = model(input_ids=input_ids, token_type_ids=token_type_ids)\n",
    "        # Classification loss\n",
    "        loss = loss_fn(logits, labels)\n",
    "        # Running accuracy over the epoch\n",
    "        probs = paddle.nn.functional.softmax(logits, axis=1)\n",
    "        correct = metric.compute(probs, labels)\n",
    "        metric.update(correct)\n",
    "        acc = metric.accumulate()\n",
    "\n",
    "        # Backward pass and parameter update\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.clear_grad()\n",
    "        # Periodic logging; print the scalar loss rather than the tensor repr\n",
    "        if iter_id%config.print_loss == 0:\n",
    "            print('epoch:{}, iter_id:{}, loss:{}, acc:{}'.format(epoch, iter_id, float(loss), acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a181a9d-0ce6-4547-a5d5-8357e6a606dd",
   "metadata": {},
   "source": [
    "## 任务7：文本匹配模型（SimCSE模型）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3c589d07-0952-46e0-b460-23f690075e72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom in-memory dataset of pre-tokenized examples\n",
    "class SimDataset(Dataset):\n",
    "    \"\"\"Stores tokenized input columns and serves them as per-example tuples.\"\"\"\n",
    "\n",
    "    def __init__(self, data_dict):\n",
    "        super().__init__()\n",
    "        self.input_ids = data_dict['input_ids']\n",
    "        self.token_type_ids = data_dict['token_type_ids']\n",
    "        self.attention_mask = data_dict['attention_mask']\n",
    "        self.labels = data_dict['labels']\n",
    "        self.len = len(self.input_ids)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # One example: (input_ids, token_type_ids, attention_mask, label)\n",
    "        return (self.input_ids[index],\n",
    "                self.token_type_ids[index],\n",
    "                self.attention_mask[index],\n",
    "                self.labels[index])\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.len\n",
    "\n",
    "\n",
    "# Collate function: batches variable-length examples with dynamic padding\n",
    "class Collator:\n",
    "    \"\"\"Pads/truncates tokenized examples to a common length and batches them.\"\"\"\n",
    "\n",
    "    def __init__(self, tokenizer, max_seq_len):\n",
    "        self.tokenizer = tokenizer  # used for its sep_token_id when truncating\n",
    "        self.max_seq_len = max_seq_len  # hard cap on the padded length\n",
    "\n",
    "    def pad(self, input_ids_list, token_type_ids_list, attention_mask_list, labels_list, max_seq_len):\n",
    "        \"\"\"Pad (or truncate) every field to max_seq_len and stack into batch tensors.\"\"\"\n",
    "        # Pre-allocate zero-filled batch tensors (0 doubles as the pad id)\n",
    "        input_ids = paddle.zeros((len(input_ids_list), max_seq_len), dtype='int64')\n",
    "        token_type_ids = paddle.zeros_like(input_ids)\n",
    "        attention_mask = paddle.zeros_like(input_ids)\n",
    "        # Copy each example into its row\n",
    "        for i in range(len(input_ids_list)):\n",
    "            seq_len = len(input_ids_list[i])\n",
    "\n",
    "            if seq_len < max_seq_len:  # shorter than target: left-aligned copy, rest stays zero\n",
    "                input_ids[i, :seq_len] = paddle.to_tensor(input_ids_list[i], dtype='int64')\n",
    "                token_type_ids[i, :seq_len] = paddle.to_tensor(token_type_ids_list[i], dtype='int64')\n",
    "                attention_mask[i, :seq_len] = paddle.to_tensor(attention_mask_list[i], dtype='int64')\n",
    "            else:  # longer or equal: truncate\n",
    "                # Drop the tail and re-append the [SEP] id so the sequence still ends correctly\n",
    "                input_ids[i] = paddle.to_tensor(\n",
    "                    input_ids_list[i][:max_seq_len - 1] + [self.tokenizer.sep_token_id], dtype='int64')\n",
    "                token_type_ids[i] = paddle.to_tensor(\n",
    "                    token_type_ids_list[i][:max_seq_len], dtype='int64')\n",
    "                attention_mask[i] = paddle.to_tensor(\n",
    "                    attention_mask_list[i][:max_seq_len], dtype='int64')\n",
    "        # Labels as a column vector, shape (batch, 1)\n",
    "        labels = paddle.to_tensor([[label] for label in labels_list], dtype='int64')\n",
    "\n",
    "        return input_ids, token_type_ids, attention_mask, labels\n",
    "\n",
    "    def __call__(self, examples):\n",
    "        \"\"\"Collate a list of example tuples into a dict of padded batch tensors.\"\"\"\n",
    "        # Unzip the per-example tuples into per-field lists\n",
    "        input_ids_list, token_type_ids_list, attention_mask_list, labels_list = list(zip(*examples))\n",
    "        # Dynamic padding: pad to the batch's longest sequence, capped by max_seq_len\n",
    "        cur_seq_len = max([len(ids) for ids in input_ids_list])  # longest sequence in this batch\n",
    "        max_seq_len = min(cur_seq_len, self.max_seq_len)  # effective padded length\n",
    "        # Pad/truncate every field\n",
    "        input_ids, token_type_ids, attention_mask, labels = self.pad(input_ids_list, token_type_ids_list,\n",
    "                                                                     attention_mask_list, labels_list, max_seq_len)\n",
    "        # Return the batch as a dict of tensors\n",
    "        data = {\n",
    "            'input_ids': input_ids,\n",
    "            'token_type_ids': token_type_ids,\n",
    "            'attention_mask': attention_mask,\n",
    "            'labels': labels,\n",
    "        }\n",
    "        return data\n",
    "\n",
    "\n",
    "# Build the training dataloader\n",
    "def create_dataloader(config):\n",
    "    \"\"\"Read the CSV of sentence pairs, tokenize them and wrap them in a DataLoader.\n",
    "\n",
    "    `config` must provide: tokenizer, max_seq_len and batch_size.\n",
    "    Returns a paddle DataLoader yielding dicts of padded batch tensors.\n",
    "    \"\"\"\n",
    "    # Only the first 2000 rows are used to keep the demo fast\n",
    "    train = pd.read_csv('data/train.csv')[:2000]\n",
    "    data_dict = {'input_ids': [], 'token_type_ids': [], 'attention_mask': [], 'labels': []}\n",
    "    # itertuples avoids positional Series indexing (row[0]), which is\n",
    "    # deprecated in recent pandas, and is faster than iterrows;\n",
    "    # total= gives tqdm a proper progress bar\n",
    "    for row in tqdm(train.itertuples(index=False), total=len(train)):\n",
    "        seq_a, seq_b, label = row[0], row[1], row[2]\n",
    "        # Encode the sentence pair in one pass ([CLS] a [SEP] b [SEP])\n",
    "        inputs_dict = config.tokenizer.encode(seq_a, seq_b, return_special_tokens_mask=True,\n",
    "                                              return_token_type_ids=True, return_attention_mask=True)\n",
    "        data_dict['input_ids'].append(inputs_dict['input_ids'])\n",
    "        data_dict['token_type_ids'].append(inputs_dict['token_type_ids'])\n",
    "        data_dict['attention_mask'].append(inputs_dict['attention_mask'])\n",
    "        data_dict['labels'].append(label)\n",
    "    # Wrap the tokenized columns in a Dataset\n",
    "    train_dataset = SimDataset(data_dict)\n",
    "\n",
    "    # Batch with dynamic padding via the Collator\n",
    "    collate_fn = Collator(config.tokenizer, config.max_seq_len)\n",
    "    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, collate_fn=collate_fn, shuffle=True,\n",
    "                                  num_workers=0)\n",
    "    return train_dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "39f62dc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SimCSE model: sentence encoder trained with in-batch negatives\n",
    "class SimCSE(nn.Layer):\n",
    "    \"\"\"SimCSE-style sentence embedding model.\n",
    "\n",
    "    Args:\n",
    "        pretrained_model: transformer returning (sequence_output, pooled_output).\n",
    "        dropout: dropout prob applied to the pooled embedding (default 0.1).\n",
    "        margin: value subtracted from the positive-pair (diagonal) similarities.\n",
    "        scale: temperature multiplier applied to the similarity matrix.\n",
    "        output_emb_size: if a positive int, project embeddings down to this size.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 pretrained_model,\n",
    "                 dropout=None,\n",
    "                 margin=0.0,\n",
    "                 scale=20,\n",
    "                 output_emb_size=None):\n",
    "\n",
    "        super().__init__()\n",
    "\n",
    "        self.ptm = pretrained_model\n",
    "        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)\n",
    "\n",
    "        self.output_emb_size = output_emb_size\n",
    "        # Guard against the default None: the original `output_emb_size > 0`\n",
    "        # raises a TypeError when output_emb_size is None\n",
    "        if output_emb_size is not None and output_emb_size > 0:\n",
    "            weight_attr = paddle.ParamAttr(\n",
    "                initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))\n",
    "            self.emb_reduce_linear = paddle.nn.Linear(768,\n",
    "                                                      output_emb_size,\n",
    "                                                      weight_attr=weight_attr)\n",
    "\n",
    "        self.margin = margin\n",
    "        self.scale = scale  # fixed typo: was `self.sacle`\n",
    "        self.sacle = scale  # backward-compat alias for the old misspelled name\n",
    "\n",
    "    def get_pooled_embedding(self,\n",
    "                             input_ids,\n",
    "                             token_type_ids=None,\n",
    "                             position_ids=None,\n",
    "                             attention_mask=None,\n",
    "                             with_pooler=True):\n",
    "        \"\"\"Encode a batch and return L2-normalized sentence embeddings.\n",
    "\n",
    "        position_ids was added so callers (cosine_sim/forward) can pass it\n",
    "        through; they already passed it positionally, which landed in the\n",
    "        attention_mask slot and made cosine_sim crash on a duplicated\n",
    "        with_pooler argument.\n",
    "        \"\"\"\n",
    "        # Pass by keyword: positionally, the encoder's third argument is\n",
    "        # position_ids, so the original call fed attention_mask into it\n",
    "        sequence_output, cls_embedding = self.ptm(input_ids,\n",
    "                                                  token_type_ids=token_type_ids,\n",
    "                                                  position_ids=position_ids,\n",
    "                                                  attention_mask=attention_mask)\n",
    "\n",
    "        if not with_pooler:\n",
    "            # Use the raw [CLS] hidden state instead of the tanh-pooled output\n",
    "            cls_embedding = sequence_output[:, 0, :]\n",
    "\n",
    "        if self.output_emb_size is not None and self.output_emb_size > 0:\n",
    "            cls_embedding = self.emb_reduce_linear(cls_embedding)\n",
    "\n",
    "        cls_embedding = self.dropout(cls_embedding)\n",
    "        cls_embedding = functional.normalize(cls_embedding, p=2, axis=-1)\n",
    "\n",
    "        return cls_embedding\n",
    "\n",
    "    def cosine_sim(self,\n",
    "                   query_input_ids,\n",
    "                   title_input_ids,\n",
    "                   query_token_type_ids=None,\n",
    "                   query_position_ids=None,\n",
    "                   query_attention_mask=None,\n",
    "                   title_token_type_ids=None,\n",
    "                   title_position_ids=None,\n",
    "                   title_attention_mask=None,\n",
    "                   with_pooler=True):\n",
    "        \"\"\"Element-wise cosine similarity between query and title embeddings.\"\"\"\n",
    "\n",
    "        query_cls_embedding = self.get_pooled_embedding(query_input_ids,\n",
    "                                                        query_token_type_ids,\n",
    "                                                        query_position_ids,\n",
    "                                                        query_attention_mask,\n",
    "                                                        with_pooler=with_pooler)\n",
    "\n",
    "        title_cls_embedding = self.get_pooled_embedding(title_input_ids,\n",
    "                                                        title_token_type_ids,\n",
    "                                                        title_position_ids,\n",
    "                                                        title_attention_mask,\n",
    "                                                        with_pooler=with_pooler)\n",
    "\n",
    "        # Embeddings are already L2-normalized, so the dot product is cosine\n",
    "        cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,\n",
    "                                axis=-1)\n",
    "        return cosine_sim\n",
    "\n",
    "    def forward(self,\n",
    "                query_input_ids,\n",
    "                title_input_ids,\n",
    "                query_token_type_ids=None,\n",
    "                query_position_ids=None,\n",
    "                query_attention_mask=None,\n",
    "                title_token_type_ids=None,\n",
    "                title_position_ids=None,\n",
    "                title_attention_mask=None):\n",
    "        \"\"\"In-batch-negatives training loss: query i's positive is title i.\"\"\"\n",
    "\n",
    "        query_cls_embedding = self.get_pooled_embedding(query_input_ids,\n",
    "                                                        query_token_type_ids,\n",
    "                                                        query_position_ids,\n",
    "                                                        query_attention_mask)\n",
    "\n",
    "        title_cls_embedding = self.get_pooled_embedding(title_input_ids,\n",
    "                                                        title_token_type_ids,\n",
    "                                                        title_position_ids,\n",
    "                                                        title_attention_mask)\n",
    "\n",
    "        # Full query-by-title similarity matrix for the batch\n",
    "        cosine_sim = paddle.matmul(query_cls_embedding,\n",
    "                                   title_cls_embedding,\n",
    "                                   transpose_y=True)\n",
    "\n",
    "        # Subtract the margin from the diagonal (positive pairs) only\n",
    "        margin_diag = paddle.full(shape=[query_cls_embedding.shape[0]],\n",
    "                                  fill_value=self.margin,\n",
    "                                  dtype=paddle.get_default_dtype())\n",
    "\n",
    "        cosine_sim = cosine_sim - paddle.diag(margin_diag)\n",
    "\n",
    "        # Temperature scaling sharpens the softmax distribution\n",
    "        cosine_sim *= self.scale\n",
    "\n",
    "        # Label i: the i-th query matches the i-th title\n",
    "        labels = paddle.arange(0, query_cls_embedding.shape[0], dtype='int64')\n",
    "        labels = paddle.reshape(labels, shape=[-1, 1])\n",
    "\n",
    "        loss = functional.cross_entropy(input=cosine_sim, label=labels)\n",
    "\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a120b680",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2000it [00:00, 3278.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0, iter_id:0, loss:Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [3.24242687])\n",
      "epoch:0, iter_id:50, loss:Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [2.88035679])\n",
      "epoch:0, iter_id:100, loss:Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [2.79069757])\n",
      "epoch:1, iter_id:0, loss:Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [2.94561958])\n",
      "epoch:1, iter_id:50, loss:Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [2.76084089])\n",
      "epoch:1, iter_id:100, loss:Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [2.85244846])\n",
      "epoch:2, iter_id:0, loss:Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [2.81766462])\n",
      "epoch:2, iter_id:50, loss:Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [2.81105328])\n",
      "epoch:2, iter_id:100, loss:Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [2.80228639])\n"
     ]
    }
   ],
   "source": [
    "class Config:\n",
    "    # --- Data loading ---\n",
    "    max_seq_len = 15 * 2  # maximum sentence length (two 15-token sentences)\n",
    "    train_data = 'data/train.csv'\n",
    "    val_data = 'data/val.csv'\n",
    "    test_data = 'data/test.csv'\n",
    "    # --- Model ---\n",
    "    model_path = 'ernie-3.0-base-zh'  # local model path / pretrained model name\n",
    "    tokenizer = None  # tokenizer object (assigned right after instantiation below)\n",
    "    # --- Training ---\n",
    "    mode = 'train'\n",
    "    learning_rate = 3e-5\n",
    "    batch_size = 16  # batch size\n",
    "    epochs = 15  # number of training epochs\n",
    "    print_loss = 50  # log the loss every this many iterations\n",
    "    num_labels = 2  # number of classes\n",
    "    # --- SimCSE hyperparameters ---\n",
    "    dropout = 0.3\n",
    "    output_emb_size = 256\n",
    "config = Config()\n",
    "config.tokenizer = AutoTokenizer.from_pretrained(config.model_path)\n",
    "train_dataloader = create_dataloader(config)\n",
    "# Build the model\n",
    "model = SimCSE(pretrained_model = AutoModel.from_pretrained(config.model_path),\n",
    "               dropout=config.dropout, output_emb_size=config.output_emb_size)\n",
    "# Optimizer\n",
    "opt = paddle.optimizer.AdamW(learning_rate=config.learning_rate, parameters=model.parameters())\n",
    "# Loss function and metric\n",
    "# NOTE(review): loss_fn and metric are never used below — SimCSE.forward\n",
    "# returns the loss itself; these two lines could be removed.\n",
    "loss_fn = nn.loss.CrossEntropyLoss()\n",
    "metric = paddle.metric.Accuracy()\n",
    "# Training loop\n",
    "for epoch in range(config.epochs):\n",
    "    model.train()\n",
    "    for iter_id, mini_batch in enumerate(train_dataloader):\n",
    "        input_ids = mini_batch['input_ids']\n",
    "        token_type_ids = mini_batch['token_type_ids']\n",
    "        attention_mask = mini_batch['attention_mask']\n",
    "        labels = mini_batch['labels']  # NOTE(review): read but never used in this loop\n",
    "        # NOTE(review): this positional call binds attention_mask to forward's\n",
    "        # title_input_ids and token_type_ids to query_token_type_ids — confirm\n",
    "        # this is intended; unsupervised SimCSE typically passes the same\n",
    "        # input_ids as both the query and the title.\n",
    "        loss = model(input_ids, attention_mask, token_type_ids)\n",
    "        # Loss is computed inside the model's forward\n",
    "        # Backpropagate and update parameters\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.clear_grad()\n",
    "        # Periodically report the training loss\n",
    "        if iter_id % config.print_loss == 0:\n",
    "            print('epoch:{}, iter_id:{}, loss:{}'.format(epoch, iter_id, loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "771e4078",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
